Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 05:04:47 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
This commit is contained in:
parent 3c6b831c26
commit 079bf3cb26
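Most of the changes in the diff below are mechanical reformatting from the pre-commit update: the yapf hook is replaced by black (line length 120), so single-quoted strings become double-quoted and call sites that no longer fit on one line are rewritten with one argument per line and a trailing comma. A minimal before/after sketch of that effect (hypothetical code, not taken from the repository):

```python
# Hypothetical snippet, for illustration only; the function and values are made up.
def report(name, value, unit, precision=2):
    return f"{name}: {value:.{precision}f} {unit}"


# yapf-era style (before): single quotes, arguments packed onto one line
# print(report('Peak CUDA mem', 3.14, 'GB', precision=2))

# black style (after): double quotes; once a call exceeds the 120-column limit,
# black splits it with one argument per line and a trailing comma, as sketched here
print(
    report(
        "Peak CUDA mem",
        3.14,
        "GB",
        precision=2,
    )
)
```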
.flake8 (deleted; 22 lines removed)
@ -1,22 +0,0 @@
-[flake8]
-ignore =
-    ;W503 line break before binary operator
-    W503,
-    ;E203 whitespace before ':'
-    E203,
-
-; exclude file
-exclude =
-    .tox,
-    .git,
-    __pycache__,
-    build,
-    dist,
-    *.pyc,
-    *.egg-info,
-    .cache,
-    .eggs
-
-max-line-length = 120
-
-per-file-ignores = __init__.py:F401
.github/workflows/scripts/check_doc_i18n.py (vendored; 12 lines changed)
@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2):

# If the corresponding item doesn't exist in the second directory, the directories are different
if not os.path.exists(item_path2):
-print(f'Found mismatch: {item_path1}, {item_path2}')
+print(f"Found mismatch: {item_path1}, {item_path2}")
return False

# If the corresponding item is a directory, we compare the two directories recursively
if os.path.isdir(item_path1) and os.path.isdir(item_path2):
if not compare_dirs(item_path1, item_path2):
-print(f'Found mismatch: {item_path1}, {item_path2}')
+print(f"Found mismatch: {item_path1}, {item_path2}")
return False

# both are files
@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):

# If the corresponding item is not a file or a directory, the directories are different
else:
-print(f'Found mismatch: {item_path1}, {item_path2}')
+print(f"Found mismatch: {item_path1}, {item_path2}")
return False

# If all items are the same, the directories are the same
return True


-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
-parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.")
+parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
args = parser.parse_args()

i18n_folders = os.listdir(args.directory)
@ -56,7 +56,7 @@ if __name__ == '__main__':
for i in range(1, len(i18n_folders)):
dir1 = i18n_folders[0]
dir2 = i18n_folders[i]
-print(f'comparing {dir1} vs {dir2}')
+print(f"comparing {dir1} vs {dir2}")
match = compare_dirs(i18n_folders[0], i18n_folders[i])

if not match:
@ -4,7 +4,7 @@ import os

def check_inputs(input_list):
for path in input_list:
-real_path = os.path.join('examples', path)
+real_path = os.path.join("examples", path)
if not os.path.exists(real_path):
return False
return True
@ -12,16 +12,16 @@ def check_inputs(input_list):

def main():
parser = argparse.ArgumentParser()
-parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
+parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
args = parser.parse_args()
name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list)

if is_correct:
-print('success')
+print("success")
else:
-print('failure')
+print("failure")


-if __name__ == '__main__':
+if __name__ == "__main__":
main()
@ -17,21 +17,21 @@ def show_files(path, all_files):


def join(input_list, sep=None):
-return (sep or ' ').join(input_list)
+return (sep or " ").join(input_list)


def main():
-contents = show_files('examples/', [])
+contents = show_files("examples/", [])
all_loc = []
for file_loc in contents:
-split_loc = file_loc.split('/')
+split_loc = file_loc.split("/")
# must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
if len(split_loc) >= 4:
-re_loc = '/'.join(split_loc[1:3])
+re_loc = "/".join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)


-if __name__ == '__main__':
+if __name__ == "__main__":
main()
@ -3,7 +3,7 @@ import argparse

def main():
parser = argparse.ArgumentParser()
-parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
+parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
@ -15,10 +15,10 @@ def main():
# - application
# - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
-folder_need_check.add('/'.join(loc.split("/")[1:3]))
+folder_need_check.add("/".join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values.
print(list(folder_need_check))


-if __name__ == '__main__':
+if __name__ == "__main__":
main()
@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]:
|
||||
|
||||
# prepare header
|
||||
headers = {
|
||||
'Authorization': f'Bearer {github_token}',
|
||||
'Accept': 'application/vnd.github+json',
|
||||
'X-GitHub-Api-Version': '2022-11-28'
|
||||
"Authorization": f"Bearer {github_token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
|
||||
res = requests.get(url, headers=headers).json()
|
||||
repo_list = []
|
||||
|
||||
for item in res:
|
||||
repo_list.append(item['name'])
|
||||
repo_list.append(item["name"])
|
||||
return repo_list
|
||||
|
||||
|
||||
@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
|
||||
"""
|
||||
# prepare header
|
||||
headers = {
|
||||
'Authorization': f'Bearer {github_token}',
|
||||
'Accept': 'application/vnd.github+json',
|
||||
'X-GitHub-Api-Version': '2022-11-28'
|
||||
"Authorization": f"Bearer {github_token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
|
||||
user_engagement_count = {}
|
||||
@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
|
||||
# do pagination to the API
|
||||
page = 1
|
||||
while True:
|
||||
comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}'
|
||||
comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
|
||||
comment_response = requests.get(comment_api, headers=headers).json()
|
||||
|
||||
if len(comment_response) == 0:
|
||||
break
|
||||
else:
|
||||
for item in comment_response:
|
||||
comment_author_relationship = item['author_association']
|
||||
if comment_author_relationship != 'MEMBER':
|
||||
comment_author_relationship = item["author_association"]
|
||||
if comment_author_relationship != "MEMBER":
|
||||
# if the comment is not made by our member
|
||||
# we don't count this comment towards user engagement
|
||||
continue
|
||||
|
||||
issue_id = item['issue_url'].split('/')[-1]
|
||||
issue_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}'
|
||||
issue_id = item["issue_url"].split("/")[-1]
|
||||
issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}"
|
||||
issue_response = requests.get(issue_api, headers=headers).json()
|
||||
issue_author_relationship = issue_response['author_association']
|
||||
issue_author_relationship = issue_response["author_association"]
|
||||
|
||||
if issue_author_relationship != 'MEMBER':
|
||||
if issue_author_relationship != "MEMBER":
|
||||
# this means that the issue/PR is not created by our own people
|
||||
# any comments in this issue/PR by our member will be counted towards the leaderboard
|
||||
member_name = item['user']['login']
|
||||
member_name = item["user"]["login"]
|
||||
|
||||
if member_name in user_engagement_count:
|
||||
user_engagement_count[member_name] += 1
|
||||
@ -153,7 +153,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
if cursor is None:
|
||||
offset_str = ""
|
||||
else:
|
||||
offset_str = f", after: \"{cursor}\""
|
||||
offset_str = f', after: "{cursor}"'
|
||||
query = f"""
|
||||
{{
|
||||
repository(owner: "{org_name}", name: "{repo_name}"){{
|
||||
@ -182,7 +182,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
if cursor is None:
|
||||
offset_str = ""
|
||||
else:
|
||||
offset_str = f", before: \"{cursor}\""
|
||||
offset_str = f', before: "{cursor}"'
|
||||
query = f"""
|
||||
{{
|
||||
repository(owner: "{org_name}", name: "{repo_name}"){{
|
||||
@ -220,8 +220,8 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
# a utility function to make call to Github GraphQL API
|
||||
def _call_graphql_api(query):
|
||||
headers = {"Authorization": f"Bearer {github_token}"}
|
||||
json_data = {'query': query}
|
||||
response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers)
|
||||
json_data = {"query": query}
|
||||
response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers)
|
||||
data = response.json()
|
||||
return data
|
||||
|
||||
@ -234,21 +234,21 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
data = _call_graphql_api(query)
|
||||
found_discussion_out_of_time_range = False
|
||||
|
||||
edges = data['data']['repository']['discussions']['edges']
|
||||
edges = data["data"]["repository"]["discussions"]["edges"]
|
||||
if len(edges) == 0:
|
||||
break
|
||||
else:
|
||||
# keep the discussion whose author is not a member
|
||||
for edge in edges:
|
||||
# print the discussion title
|
||||
discussion = edge['node']
|
||||
discussion_updated_at = str2datetime(discussion['updatedAt'])
|
||||
discussion = edge["node"]
|
||||
discussion_updated_at = str2datetime(discussion["updatedAt"])
|
||||
|
||||
# check if the updatedAt is within the last 7 days
|
||||
# if yes, add it to discussion_numbers
|
||||
if discussion_updated_at > since:
|
||||
if discussion['authorAssociation'] != 'MEMBER':
|
||||
discussion_numbers.append(discussion['number'])
|
||||
if discussion["authorAssociation"] != "MEMBER":
|
||||
discussion_numbers.append(discussion["number"])
|
||||
else:
|
||||
found_discussion_out_of_time_range = True
|
||||
|
||||
@ -256,7 +256,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
break
|
||||
else:
|
||||
# update cursor
|
||||
cursor = edges[-1]['cursor']
|
||||
cursor = edges[-1]["cursor"]
|
||||
|
||||
# get the discussion comments and replies made by our member
|
||||
user_engagement_count = {}
|
||||
@ -269,42 +269,42 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
data = _call_graphql_api(query)
|
||||
|
||||
# get the comments
|
||||
edges = data['data']['repository']['discussion']['comments']['edges']
|
||||
edges = data["data"]["repository"]["discussion"]["comments"]["edges"]
|
||||
|
||||
# update the cursor
|
||||
if len(edges) == 0:
|
||||
break
|
||||
else:
|
||||
# update cursor for pagination
|
||||
cursor = edges[-1]['cursor']
|
||||
cursor = edges[-1]["cursor"]
|
||||
|
||||
for edge in edges:
|
||||
comment = edge['node']
|
||||
if comment['authorAssociation'] == 'MEMBER':
|
||||
comment = edge["node"]
|
||||
if comment["authorAssociation"] == "MEMBER":
|
||||
# check if the updatedAt is within the last 7 days
|
||||
# if yes, add it to user_engagement_count
|
||||
comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
|
||||
comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
|
||||
if comment_updated_at > since:
|
||||
member_name = comment['author']['login']
|
||||
member_name = comment["author"]["login"]
|
||||
if member_name in user_engagement_count:
|
||||
user_engagement_count[member_name] += 1
|
||||
else:
|
||||
user_engagement_count[member_name] = 1
|
||||
|
||||
# get the replies
|
||||
reply_edges = comment['replies']['edges']
|
||||
reply_edges = comment["replies"]["edges"]
|
||||
if len(reply_edges) == 0:
|
||||
continue
|
||||
else:
|
||||
for reply_edge in reply_edges:
|
||||
reply = reply_edge['node']
|
||||
if reply['authorAssociation'] == 'MEMBER':
|
||||
reply = reply_edge["node"]
|
||||
if reply["authorAssociation"] == "MEMBER":
|
||||
# check if the updatedAt is within the last 7 days
|
||||
# if yes, add it to discussion_numbers
|
||||
|
||||
reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
|
||||
reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
|
||||
if reply_updated_at > since:
|
||||
member_name = reply['author']['login']
|
||||
member_name = reply["author"]["login"]
|
||||
if member_name in user_engagement_count:
|
||||
user_engagement_count[member_name] += 1
|
||||
else:
|
||||
@ -312,7 +312,9 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
|
||||
return user_engagement_count
|
||||
|
||||
|
||||
def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool:
|
||||
def generate_user_engagement_leaderboard_image(
|
||||
github_token: str, org_name: str, repo_list: List[str], output_path: str
|
||||
) -> bool:
|
||||
"""
|
||||
Generate the user engagement leaderboard image for stats within the last 7 days
|
||||
|
||||
@ -335,11 +337,14 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
|
||||
else:
|
||||
total_engagement_count[name] = count
|
||||
|
||||
|
||||
for repo_name in repo_list:
|
||||
print(f"Fetching user engagement count for {repo_name}/{repo_name}")
|
||||
issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str)
|
||||
discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime)
|
||||
issue_pr_engagement_count = get_issue_pull_request_comments(
|
||||
github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str
|
||||
)
|
||||
discussion_engagement_count = get_discussion_comments(
|
||||
github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime
|
||||
)
|
||||
|
||||
# update the total engagement count
|
||||
_update_count(issue_pr_engagement_count)
|
||||
@ -363,7 +368,7 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
|
||||
# plot the leaderboard
|
||||
xlabel = f"Number of Comments made (since {start_datetime_str})"
|
||||
ylabel = "Member"
|
||||
title = 'Active User Engagement Leaderboard'
|
||||
title = "Active User Engagement Leaderboard"
|
||||
plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
|
||||
return True
|
||||
else:
|
||||
@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
|
||||
"""
|
||||
# request to the Github API to get the users who have contributed in the last 7 days
|
||||
headers = {
|
||||
'Authorization': f'Bearer {github_token}',
|
||||
'Accept': 'application/vnd.github+json',
|
||||
'X-GitHub-Api-Version': '2022-11-28'
|
||||
"Authorization": f"Bearer {github_token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
}
|
||||
|
||||
counter = Counter()
|
||||
start_datetime = get_utc_time_one_week_ago()
|
||||
|
||||
def _get_url(org_name, repo_name, page):
|
||||
return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed'
|
||||
return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed"
|
||||
|
||||
def _iterate_by_page(org_name, repo_name):
|
||||
page = 1
|
||||
@ -415,8 +420,8 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
|
||||
|
||||
# count the pull request and author from response
|
||||
for pr_data in response:
|
||||
merged_at = pr_data['merged_at']
|
||||
author = pr_data['user']['login']
|
||||
merged_at = pr_data["merged_at"]
|
||||
author = pr_data["user"]["login"]
|
||||
|
||||
if merged_at is None:
|
||||
continue
|
||||
@ -439,7 +444,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
|
||||
_iterate_by_page(org_name, repo_name)
|
||||
|
||||
# convert unix timestamp to Beijing datetime
|
||||
bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai'))
|
||||
bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai"))
|
||||
bj_start_datetime_str = datetime2str(bj_start_datetime)
|
||||
|
||||
contribution_list = counter.to_sorted_list()
|
||||
@ -452,7 +457,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
|
||||
if len(author_list) > 0:
|
||||
xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})"
|
||||
ylabel = "Contributor"
|
||||
title = 'Active Contributor Leaderboard'
|
||||
title = "Active Contributor Leaderboard"
|
||||
plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
|
||||
return True
|
||||
else:
|
||||
@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str:
|
||||
image_path (str): the path to the image to be uploaded
|
||||
"""
|
||||
url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path
|
||||
form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path
|
||||
multi_form = MultipartEncoder(form)
|
||||
headers = {
|
||||
'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token
|
||||
"Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token
|
||||
}
|
||||
headers['Content-Type'] = multi_form.content_type
|
||||
headers["Content-Type"] = multi_form.content_type
|
||||
response = requests.request("POST", url, headers=headers, data=multi_form).json()
|
||||
return response['data']['image_key']
|
||||
return response["data"]["image_key"]
|
||||
|
||||
|
||||
def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
|
||||
@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
|
||||
app_id (str): Lark app id
|
||||
app_secret (str): Lark app secret
|
||||
"""
|
||||
url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
|
||||
data = {'app_id': app_id, 'app_secret': app_secret}
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
|
||||
data = {"app_id": app_id, "app_secret": app_secret}
|
||||
response = requests.post(url, json=data).json()
|
||||
return response['tenant_access_token']
|
||||
return response["tenant_access_token"]
|
||||
|
||||
|
||||
def send_image_to_lark(image_key: str, webhook_url: str) -> None:
|
||||
@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str):
|
||||
requests.post(webhook_url, json=data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
GITHUB_TOKEN = os.environ['GITHUB_TOKEN']
|
||||
CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png'
|
||||
USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png'
|
||||
if __name__ == "__main__":
|
||||
GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
|
||||
CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png"
|
||||
USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png"
|
||||
ORG_NAME = "hpcaitech"
|
||||
|
||||
# get all open source repositories
|
||||
@ -527,17 +532,19 @@ if __name__ == '__main__':
|
||||
|
||||
# generate images
|
||||
contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH)
|
||||
engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH)
|
||||
engagement_success = generate_user_engagement_leaderboard_image(
|
||||
GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH
|
||||
)
|
||||
|
||||
# upload images
|
||||
APP_ID = os.environ['LARK_APP_ID']
|
||||
APP_SECRET = os.environ['LARK_APP_SECRET']
|
||||
APP_ID = os.environ["LARK_APP_ID"]
|
||||
APP_SECRET = os.environ["LARK_APP_SECRET"]
|
||||
LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET)
|
||||
contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH)
|
||||
user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)
|
||||
|
||||
# send message to lark
|
||||
LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL']
|
||||
LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"]
|
||||
message = """本周的社区榜单出炉啦!
|
||||
1. 开发贡献者榜单
|
||||
2. 用户互动榜单
|
||||
|
@ -7,27 +7,27 @@ import re

import requests

-COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
-TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'
+COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits"
+TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags"


def parse_args():
parser = argparse.ArgumentParser()
-parser.add_argument('--out', type=str, help='output path for the release draft', required=True)
-parser.add_argument('--version', type=str, help='current version to release', required=True)
+parser.add_argument("--out", type=str, help="output path for the release draft", required=True)
+parser.add_argument("--version", type=str, help="current version to release", required=True)
return parser.parse_args()


def get_latest_tag_commit(headers=None):
res = requests.get(url=TAGS_API, headers=headers)
data = res.json()
-commit_hash = data[0]['commit']['sha']
-version = data[0]['name']
+commit_hash = data[0]["commit"]["sha"]
+version = data[0]["name"]
return commit_hash, version


def get_commit_info(commit_hash, headers=None):
-api = f'{COMMIT_API}/{commit_hash}'
+api = f"{COMMIT_API}/{commit_hash}"
res = requests.get(url=api, headers=headers)
return res.json()

@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None):
results = []

while True:
-api = f'{COMMIT_API}?since={since}&per_page=100&page={page}'
+api = f"{COMMIT_API}?since={since}&per_page=100&page={page}"
resp = requests.get(url=api, headers=headers)
data = resp.json()

@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None):

def collate_release_info(commit_info_list):
results = dict()
-pattern = pattern = r'\[.*\]'
+pattern = pattern = r"\[.*\]"

for commit_info in commit_info_list:
-author = commit_info['commit']['author']['name']
+author = commit_info["commit"]["author"]["name"]

try:
-author_url = commit_info['author']['url']
+author_url = commit_info["author"]["url"]
except:
# author can be None
author_url = None
-msg = commit_info['commit']['message']
+msg = commit_info["commit"]["message"]
match = re.search(pattern, msg)

if match:
-tag = match.group().lstrip('[').rstrip(']').capitalize()
+tag = match.group().lstrip("[").rstrip("]").capitalize()
if tag not in results:
results[tag] = []
results[tag].append((msg, author, author_url))
@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info):

for msg, author, author_url in v:
# only keep the first line
-msg = msg.split('\n')[0]
+msg = msg.split("\n")[0]

if author_url:
-item = f'{msg} by [{author}]({author_url})\n'
+item = f"{msg} by [{author}]({author_url})\n"
else:
-item = f'{msg} by {author}\n'
-text.append(f'- {item}')
+item = f"{msg} by {author}\n"
+text.append(f"- {item}")

-text.append('\n')
+text.append("\n")

# add full change log
text.append(
-f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}')
+f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}"
+)

return text


-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
-token = os.environ['GITHUB_API_TOKEN']
-headers = {'Authorization': token}
+token = os.environ["GITHUB_API_TOKEN"]
+headers = {"Authorization": token}

# get previous release tag
last_release_commit, last_version = get_latest_tag_commit(headers)
last_release_commit_info = get_commit_info(last_release_commit, headers=headers)
-last_release_date = last_release_commit_info['commit']['author']['date']
+last_release_date = last_release_commit_info["commit"]["author"]["date"]

# get the commits since last release
commit_info = get_all_commit_info(since=last_release_date, headers=headers)
-commit_info = commit_info[:-1] # remove the release commit
+commit_info = commit_info[:-1]  # remove the release commit

# collate into markdown
release_info = collate_release_info(commit_info)
markdown_text = generate_release_post_markdown(args.version, last_version, release_info)

# write into a file
-with open(args.out, 'w') as f:
+with open(args.out, "w") as f:
for line in markdown_text:
f.write(line)
@ -5,8 +5,8 @@ import requests

def parse_args():
parser = argparse.ArgumentParser()
-parser.add_argument('-m', '--message', type=str)
-parser.add_argument('-u', '--url', type=str)
+parser.add_argument("-m", "--message", type=str)
+parser.add_argument("-u", "--url", type=str)
return parser.parse_args()


@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url):
requests.post(webhook_url, json=data)


-if __name__ == '__main__':
+if __name__ == "__main__":
args = parse_args()
send_message_to_lark(args.message, args.url)
@ -3,3 +3,4 @@ line_length = 120
multi_line_output=3
include_trailing_comma = true
ignore_comments = true
+profile = black
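With `profile = black` added above (together with the existing `multi_line_output=3` and `include_trailing_comma = true`), isort wraps an over-long import in the vertical hanging-indent style that black leaves untouched. A small sketch of that layout, using a placeholder import list chosen only for illustration:

```python
# Sketch of the import layout produced by isort's black profile; the imported
# names below are placeholders, not taken from the repository.
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

print(Optional[int])  # the block above is ordinary, runnable Python
```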
@ -1,23 +1,31 @@
repos:

+  - repo: https://github.com/PyCQA/autoflake
+    rev: v2.2.1
+    hooks:
+      - id: autoflake
+        name: autoflake (python)
+        args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: sort all imports (python)

-  - repo: https://github.com/pre-commit/mirrors-yapf
-    rev: v0.32.0
+  - repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 23.9.1
    hooks:
-      - id: yapf
-        name: yapf formatter
-        args: ['--style=.style.yapf', '--parallel', '--in-place']
+      - id: black
+        name: black formatter
+        args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']

  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.1
    hooks:
      - id: clang-format
        name: clang formatter
        types_or: [c++, c]

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.3.0
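The autoflake hook configured above runs with `--in-place`, `--remove-unused-variables`, `--remove-all-unused-imports` and `--ignore-init-module-imports`, so unused imports and unused local variables are stripped automatically while re-exports in `__init__.py` files are left alone. A hypothetical before/after sketch of what that hook does:

```python
# Before the autoflake hook (hypothetical file):
#     import os                 # unused import, removed by --remove-all-unused-imports
#     import sys
#
#     def entry():
#         unused = 1             # unused local, removed by --remove-unused-variables
#         return sys.argv
#
# After the hook, only the code that is actually used remains:
import sys


def entry():
    return sys.argv


print(entry())
```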
@ -1,5 +0,0 @@
-[style]
-based_on_style = google
-spaces_before_comment = 4
-split_before_logical_operator = true
-column_limit = 120
@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
||||
def preprocess_batch(samples) -> dict:
|
||||
input_ids = torch.stack(samples)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
outputs = ''
|
||||
outputs = ""
|
||||
for name, numel in model_dict.items():
|
||||
outputs += f'{name}: '
|
||||
outputs += f"{name}: "
|
||||
if numel >= B:
|
||||
outputs += f'{numel / B:.2f} B\n'
|
||||
outputs += f"{numel / B:.2f} B\n"
|
||||
elif numel >= M:
|
||||
outputs += f'{numel / M:.2f} M\n'
|
||||
outputs += f"{numel / M:.2f} M\n"
|
||||
elif numel >= K:
|
||||
outputs += f'{numel / K:.2f} K\n'
|
||||
outputs += f"{numel / K:.2f} K\n"
|
||||
else:
|
||||
outputs += f'{numel}\n'
|
||||
outputs += f"{numel}\n"
|
||||
print_rank_0(outputs)
|
||||
|
||||
|
||||
def get_gpt_config(model_name: str) -> OPTConfig:
|
||||
model_map = {
|
||||
'125m': OPTConfig.from_pretrained('facebook/opt-125m'),
|
||||
'350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
|
||||
'700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
|
||||
'1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
|
||||
'2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
|
||||
'3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
|
||||
'5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
|
||||
'6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
|
||||
'10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
|
||||
'13b': OPTConfig.from_pretrained('facebook/opt-13b'),
|
||||
"125m": OPTConfig.from_pretrained("facebook/opt-125m"),
|
||||
"350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
|
||||
"700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
|
||||
"1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
|
||||
"2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
|
||||
"3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
|
||||
"5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
|
||||
"6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
|
||||
"10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
|
||||
"13b": OPTConfig.from_pretrained("facebook/opt-13b"),
|
||||
}
|
||||
try:
|
||||
return model_map[model_name]
|
||||
@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.strategy == 'ddp':
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_gemini_cpu':
|
||||
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2_cpu':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
|
||||
elif args.strategy == 'colossalai_zero1':
|
||||
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero1_cpu':
|
||||
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_gemini_cpu":
|
||||
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero2_cpu":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
elif args.strategy == "colossalai_zero1":
|
||||
strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero1_cpu":
|
||||
strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
@ -103,90 +103,106 @@ def main(args):
|
||||
|
||||
if args.use_kernels:
|
||||
from coati.kernels import convert_to_xformer_model
|
||||
actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
|
||||
(actor, critic, initial_model, reward_model))
|
||||
|
||||
actor, critic, initial_model, reward_model = map(
|
||||
convert_to_xformer_model, (actor, critic, initial_model, reward_model)
|
||||
)
|
||||
|
||||
actor_numel = get_model_numel(actor, strategy)
|
||||
critic_numel = get_model_numel(critic, strategy)
|
||||
initial_model_numel = get_model_numel(initial_model, strategy)
|
||||
reward_model_numel = get_model_numel(reward_model, strategy)
|
||||
print_model_numel({
|
||||
'Actor': actor_numel,
|
||||
'Critic': critic_numel,
|
||||
'Initial model': initial_model_numel,
|
||||
'Reward model': reward_model_numel
|
||||
})
|
||||
performance_evaluator = PerformanceEvaluator(actor_numel,
|
||||
critic_numel,
|
||||
initial_model_numel,
|
||||
reward_model_numel,
|
||||
enable_grad_checkpoint=False,
|
||||
ignore_episodes=1)
|
||||
print_model_numel(
|
||||
{
|
||||
"Actor": actor_numel,
|
||||
"Critic": critic_numel,
|
||||
"Initial model": initial_model_numel,
|
||||
"Reward model": reward_model_numel,
|
||||
}
|
||||
)
|
||||
performance_evaluator = PerformanceEvaluator(
|
||||
actor_numel,
|
||||
critic_numel,
|
||||
initial_model_numel,
|
||||
reward_model_numel,
|
||||
enable_grad_checkpoint=False,
|
||||
ignore_episodes=1,
|
||||
)
|
||||
|
||||
if args.strategy.startswith('colossalai'):
|
||||
if args.strategy.startswith("colossalai"):
|
||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
||||
else:
|
||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
|
||||
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
|
||||
dataloader = DataLoader(random_prompts,
|
||||
batch_size=args.experience_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=preprocess_batch)
|
||||
dataloader = DataLoader(
|
||||
random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
|
||||
)
|
||||
|
||||
trainer = PPOTrainer(strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
ptx_coef=0,
|
||||
train_batch_size=args.train_batch_size,
|
||||
offload_inference_models=args.offload_inference_models,
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
use_cache=True,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=[performance_evaluator])
|
||||
trainer = PPOTrainer(
|
||||
strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
ptx_coef=0,
|
||||
train_batch_size=args.train_batch_size,
|
||||
offload_inference_models=args.offload_inference_models,
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
use_cache=True,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=[performance_evaluator],
|
||||
)
|
||||
|
||||
trainer.fit(prompt_dataloader=dataloader,
|
||||
pretrain_dataloader=None,
|
||||
num_episodes=args.num_episodes,
|
||||
num_update_steps=args.num_update_steps,
|
||||
num_collect_steps=args.num_collect_steps)
|
||||
trainer.fit(
|
||||
prompt_dataloader=dataloader,
|
||||
pretrain_dataloader=None,
|
||||
num_episodes=args.num_episodes,
|
||||
num_update_steps=args.num_update_steps,
|
||||
num_collect_steps=args.num_collect_steps,
|
||||
)
|
||||
|
||||
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
|
||||
print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='125m')
|
||||
parser.add_argument('--critic_model', default='125m')
|
||||
parser.add_argument('--strategy',
|
||||
choices=[
|
||||
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
|
||||
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
|
||||
],
|
||||
default='ddp')
|
||||
parser.add_argument('--num_episodes', type=int, default=3)
|
||||
parser.add_argument('--num_collect_steps', type=int, default=8)
|
||||
parser.add_argument('--num_update_steps', type=int, default=1)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0)
|
||||
parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
|
||||
parser.add_argument('--offload_inference_models', action='store_true', default=False)
|
||||
parser.add_argument('--use_kernels', action='store_true', default=False)
|
||||
parser.add_argument("--model", default="125m")
|
||||
parser.add_argument("--critic_model", default="125m")
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
choices=[
|
||||
"ddp",
|
||||
"colossalai_gemini",
|
||||
"colossalai_gemini_cpu",
|
||||
"colossalai_zero2",
|
||||
"colossalai_zero2_cpu",
|
||||
"colossalai_zero1",
|
||||
"colossalai_zero1_cpu",
|
||||
],
|
||||
default="ddp",
|
||||
)
|
||||
parser.add_argument("--num_episodes", type=int, default=3)
|
||||
parser.add_argument("--num_collect_steps", type=int, default=8)
|
||||
parser.add_argument("--num_update_steps", type=int, default=1)
|
||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
||||
parser.add_argument("--experience_batch_size", type=int, default=8)
|
||||
parser.add_argument("--lora_rank", type=int, default=0)
|
||||
parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
|
||||
parser.add_argument("--offload_inference_models", action="store_true", default=False)
|
||||
parser.add_argument("--use_kernels", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
s.connect(("8.8.8.8", 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
@ -36,22 +36,25 @@ def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_trainers)]
|
||||
env_info_trainers = [
|
||||
{
|
||||
"local_rank": "0",
|
||||
"rank": str(rank),
|
||||
"world_size": str(args.num_trainers),
|
||||
"master_port": trainer_port,
|
||||
"master_addr": master_addr,
|
||||
}
|
||||
for rank in range(args.num_trainers)
|
||||
]
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_maker = {
|
||||
'local_rank': '0',
|
||||
'rank': '0',
|
||||
'world_size': '1',
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
"local_rank": "0",
|
||||
"rank": "0",
|
||||
"world_size": "1",
|
||||
"master_port": maker_port,
|
||||
"master_addr": master_addr,
|
||||
}
|
||||
|
||||
# configure tokenizer
|
||||
@ -63,21 +66,27 @@ def main(args):
|
||||
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
|
||||
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
reward_model = get_reward_model_from_args(args.critic_model,
|
||||
config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
|
||||
reward_model = (
|
||||
get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
)
|
||||
if args.initial_model_quant_ckpt is not None and args.model == "llama":
|
||||
# quantize initial model
|
||||
with low_resource_init(), no_init_weights():
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
||||
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
|
||||
args.quant_group_size).cuda().requires_grad_(False)
|
||||
initial_model.model = (
|
||||
llama_load_quant(
|
||||
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
|
||||
)
|
||||
.cuda()
|
||||
.requires_grad_(False)
|
||||
)
|
||||
else:
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
return actor, critic, reward_model, initial_model
|
||||
|
||||
# configure Experience Maker
|
||||
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
|
||||
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
|
||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
||||
model_fn=model_fn,
|
||||
env_info=env_info_maker,
|
||||
@ -97,15 +106,18 @@ def main(args):
|
||||
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model,
|
||||
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
|
||||
critic = (
|
||||
get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
|
||||
.half()
|
||||
.cuda()
|
||||
)
|
||||
return actor, critic
|
||||
|
||||
# configure Trainer
|
||||
trainer_refs = [
|
||||
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
experience_maker_holder_name_list=[
|
||||
f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
|
||||
f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
|
||||
model_fn=trainer_model_fn,
|
||||
@ -114,7 +126,8 @@ def main(args):
|
||||
buffer_limit=16,
|
||||
eval_performance=True,
|
||||
debug=args.debug,
|
||||
) for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
)
|
||||
for i, env_info_trainer in enumerate(env_info_trainers)
|
||||
]
|
||||
|
||||
dataset_size = args.experience_batch_size * 4
|
||||
@ -122,7 +135,7 @@ def main(args):
|
||||
def data_gen_fn():
|
||||
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
|
||||
attn_mask = torch.ones_like(input_ids)
|
||||
return {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
return {"input_ids": input_ids, "attention_mask": attn_mask}
|
||||
|
||||
def build_dataloader(size):
|
||||
dataset = [data_gen_fn() for _ in range(size)]
|
||||
@ -138,8 +151,10 @@ def main(args):
|
||||
wait_tasks = []
|
||||
|
||||
wait_tasks.append(
|
||||
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
|
||||
num_steps=args.experience_steps))
|
||||
experience_holder_ref.workingloop.remote(
|
||||
partial(build_dataloader, dataset_size), num_steps=args.experience_steps
|
||||
)
|
||||
)
|
||||
|
||||
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
|
||||
for trainer_ref in trainer_refs:
|
||||
@ -148,31 +163,30 @@ def main(args):
|
||||
ray.get(wait_tasks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--num_trainers', type=int, default=1)
|
||||
parser.add_argument('--trainer_strategy',
|
||||
choices=[
|
||||
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
|
||||
'colossalai_zero2_cpu'
|
||||
],
|
||||
default='ddp')
|
||||
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--critic_pretrain', type=str, default=None)
|
||||
parser.add_argument('--experience_steps', type=int, default=4)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--train_epochs', type=int, default=1)
|
||||
parser.add_argument('--update_steps', type=int, default=2)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument("--num_trainers", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--trainer_strategy",
|
||||
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
|
||||
default="ddp",
|
||||
)
|
||||
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
|
||||
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
||||
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--critic_pretrain", type=str, default=None)
|
||||
parser.add_argument("--experience_steps", type=int, default=4)
|
||||
parser.add_argument("--experience_batch_size", type=int, default=8)
|
||||
parser.add_argument("--train_epochs", type=int, default=1)
|
||||
parser.add_argument("--update_steps", type=int, default=2)
|
||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
|
||||
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
|
||||
parser.add_argument('--quant_bits', type=int, default=4)
|
||||
parser.add_argument('--quant_group_size', type=int, default=128)
|
||||
parser.add_argument('--debug', action='store_true')
|
||||
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
|
||||
parser.add_argument("--quant_bits", type=int, default=4)
|
||||
parser.add_argument("--quant_group_size", type=int, default=128)
|
||||
parser.add_argument("--debug", action="store_true")
|
||||
args = parser.parse_args()
|
||||
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
|
||||
main(args)
|
||||
|
@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
|
||||
|
||||
def get_free_port():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(('8.8.8.8', 80))
|
||||
s.connect(("8.8.8.8", 80))
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
@ -36,23 +36,29 @@ def main(args):
|
||||
master_addr = str(get_local_ip())
|
||||
# trainer_env_info
|
||||
trainer_port = str(get_free_port())
|
||||
env_info_trainers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_trainers),
|
||||
'master_port': trainer_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_trainers)]
|
||||
env_info_trainers = [
|
||||
{
|
||||
"local_rank": "0",
|
||||
"rank": str(rank),
|
||||
"world_size": str(args.num_trainers),
|
||||
"master_port": trainer_port,
|
||||
"master_addr": master_addr,
|
||||
}
|
||||
for rank in range(args.num_trainers)
|
||||
]
|
||||
|
||||
# maker_env_info
|
||||
maker_port = str(get_free_port())
|
||||
env_info_makers = [{
|
||||
'local_rank': '0',
|
||||
'rank': str(rank),
|
||||
'world_size': str(args.num_makers),
|
||||
'master_port': maker_port,
|
||||
'master_addr': master_addr
|
||||
} for rank in range(args.num_makers)]
|
||||
env_info_makers = [
|
||||
{
|
||||
"local_rank": "0",
|
||||
"rank": str(rank),
|
||||
"world_size": str(args.num_makers),
|
||||
"master_port": maker_port,
|
||||
"master_addr": master_addr,
|
||||
}
|
||||
for rank in range(args.num_makers)
|
||||
]
|
||||
|
||||
# configure tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
|
||||
@ -63,14 +69,20 @@ def main(args):
|
||||
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
|
||||
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
reward_model = get_reward_model_from_args(args.critic_model,
|
||||
config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
|
||||
reward_model = (
|
||||
get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
|
||||
)
|
||||
if args.initial_model_quant_ckpt is not None and args.model == "llama":
|
||||
# quantize initial model
|
||||
with low_resource_init(), no_init_weights():
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg)
|
||||
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
|
||||
args.quant_group_size).cuda().requires_grad_(False)
|
||||
initial_model.model = (
|
||||
llama_load_quant(
|
||||
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
|
||||
)
|
||||
.cuda()
|
||||
.requires_grad_(False)
|
||||
)
|
||||
else:
|
||||
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
|
||||
return actor, critic, reward_model, initial_model
|
||||
@ -79,7 +91,7 @@ def main(args):
|
||||
experience_holder_refs = [
|
||||
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
|
||||
detached_trainer_name_list=[
|
||||
f'trainer{x}'
|
||||
f"trainer{x}"
|
||||
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
|
||||
],
|
||||
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
|
||||
@ -103,8 +115,11 @@ def main(args):
|
||||
|
||||
def trainer_model_fn():
|
||||
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
|
||||
critic = get_critic_from_args(args.critic_model,
|
||||
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
|
||||
critic = (
|
||||
get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
|
||||
.half()
|
||||
.cuda()
|
||||
)
|
||||
return actor, critic
|
||||
|
||||
# configure Trainer
|
||||
@ -130,7 +145,7 @@ def main(args):
|
||||
def data_gen_fn():
|
||||
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
|
||||
attn_mask = torch.ones_like(input_ids)
|
||||
return {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
return {"input_ids": input_ids, "attention_mask": attn_mask}
|
||||
|
||||
def build_dataloader(size):
|
||||
dataset = [data_gen_fn() for _ in range(size)]
|
||||
@ -147,43 +162,48 @@ def main(args):
|
||||
|
||||
for experience_holder_ref in experience_holder_refs:
|
||||
wait_tasks.append(
|
||||
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
|
||||
num_steps=args.experience_steps))
|
||||
experience_holder_ref.workingloop.remote(
|
||||
partial(build_dataloader, dataset_size), num_steps=args.experience_steps
|
||||
)
|
||||
)
|
||||
|
||||
total_steps = args.experience_batch_size * args.experience_steps * \
|
||||
args.num_makers // (args.num_trainers * args.train_batch_size)
|
||||
total_steps = (
|
||||
args.experience_batch_size
|
||||
* args.experience_steps
|
||||
* args.num_makers
|
||||
// (args.num_trainers * args.train_batch_size)
|
||||
)
|
||||
for trainer_ref in trainer_refs:
|
||||
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
|
||||
|
||||
ray.get(wait_tasks)
|
||||
|
||||
|
||||
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_makers', type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")

parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)
@@ -4,7 +4,10 @@ from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0

__all__ = [
'RmStaticDataset', 'HhRlhfDataset',
'SFTDataset', 'SupervisedDataset',
'PromptDataset', 'is_rank_0',
"RmStaticDataset",
"HhRlhfDataset",
"SFTDataset",
"SupervisedDataset",
"PromptDataset",
"is_rank_0",
]
@@ -49,7 +49,7 @@ class Conversation:

def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
@@ -57,12 +57,14 @@ class Conversation:
return ret

def copy(self):
return Conversation(system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep)
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
)

def dict(self):
return {
@@ -70,7 +72,7 @@ class Conversation:
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep
"sep": self.sep,
}

@@ -13,11 +13,13 @@ from .utils import jload
class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 96):
def __init__(
self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 96,
):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
self.logger = get_dist_logger()
@@ -30,11 +32,9 @@ class PromptDataset(Dataset):
list_data_dict = list_data_dict[:max_datasets_size]

instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
tokens = tokenizer(
instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
)
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()

@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):

def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
self.end_token = tokenizer.eos_token if special_token is None else special_token

chosen = [
data["prompt"] + data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

reject = [
data["prompt"] + data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length

def __getitem__(self, idx):
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
return (
self.chosen["input_ids"][idx],
self.chosen["attention_mask"][idx],
self.reject["input_ids"][idx],
self.reject["attention_mask"][idx],
)

# Anthropic/hh-rlhf
@@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset):

def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
self.end_token = tokenizer.eos_token if special_token is None else special_token

chosen = [
data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}

reject = [
data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}

def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length

def __getitem__(self, idx):
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
return (
self.chosen["input_ids"][idx],
self.chosen["attention_mask"][idx],
self.reject["input_ids"][idx],
self.reject["attention_mask"][idx],
)
@ -16,10 +16,11 @@ import copy
|
||||
from typing import Dict, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import PreTrainedTokenizer
|
||||
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .utils import is_rank_0, jload
|
||||
@ -28,32 +29,33 @@ logger = get_dist_logger()
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
PROMPT_DICT = {
|
||||
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_no_input": ("Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Response:"),
|
||||
"prompt_input": (
|
||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
|
||||
),
|
||||
"prompt_no_input": (
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Response:"
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _preprocess(sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _preprocess(
|
||||
sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
sequences = [s + t for s, t in zip(sources, targets)]
|
||||
sequences_token = tokenizer(sequences,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
sources_token = tokenizer(sources,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
sequences_token = tokenizer(
|
||||
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
sources_token = tokenizer(
|
||||
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
labels = copy.deepcopy(sequences_token["input_ids"])
|
||||
for i in range(labels.shape[0]):
|
||||
@ -64,18 +66,19 @@ def _preprocess(sources: Sequence[str],
|
||||
labels[i][:source_len] = IGNORE_INDEX
|
||||
elif tokenizer.padding_side == "left":
|
||||
# |pad|prompt|completion|eos|
|
||||
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
|
||||
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
|
||||
else:
|
||||
raise RuntimeError()
|
||||
|
||||
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
|
||||
|
||||
|
||||
def _preprocess_chatglm(sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def _preprocess_chatglm(
|
||||
sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Preprocess the data by tokenizing.
|
||||
None for attention mask, ChatGLM will calculate attention mask according to input ids
|
||||
@ -90,15 +93,15 @@ def _preprocess_chatglm(sources: Sequence[str],
|
||||
# truncate
|
||||
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
|
||||
truncate_length = max(0, len(input_id) - max_length)
|
||||
input_id = input_id[truncate_length: ]
|
||||
input_id = input_id[truncate_length:]
|
||||
if truncate_length == len(source_id) + 1:
|
||||
input_id = sp_token_list + input_id[1: ]
|
||||
input_id = sp_token_list + input_id[1:]
|
||||
elif truncate_length > len(source_id) + 1:
|
||||
input_id = sp_token_list + input_id[2: ]
|
||||
input_id = sp_token_list + input_id[2:]
|
||||
|
||||
context_length = input_id.index(tokenizer.bos_token_id)
|
||||
mask_position = context_length - 1
|
||||
label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
|
||||
label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
|
||||
|
||||
pad_len = max_length - len(input_id)
|
||||
input_id = input_id + [tokenizer.pad_token_id] * pad_len
|
||||
@ -117,25 +120,18 @@ class SFTDataset(Dataset):
|
||||
max_length: max length of input
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset: Dict,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_length: int = 512
|
||||
) -> None:
|
||||
def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
|
||||
super().__init__()
|
||||
self.input_ids = []
|
||||
|
||||
sources = [data["prompt"] for data in dataset]
|
||||
targets = [
|
||||
data["completion"] + tokenizer.eos_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0())
|
||||
]
|
||||
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
|
||||
if isinstance(tokenizer, ChatGLMTokenizer):
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess_chatglm(sources, targets, tokenizer, max_length)
|
||||
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
|
||||
sources, targets, tokenizer, max_length
|
||||
)
|
||||
else:
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
length = self.input_ids.shape[0]
|
||||
@ -143,22 +139,17 @@ class SFTDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.attention_mask is not None:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
|
||||
else:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx])
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
"""Dataset for supervised fine-tuning."""
|
||||
|
||||
def __init__(self,
|
||||
data_path: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
max_datasets_size: int = None,
|
||||
max_length: int = 512):
|
||||
def __init__(
|
||||
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
|
||||
):
|
||||
super().__init__()
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
@ -174,18 +165,15 @@ class SupervisedDataset(Dataset):
|
||||
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
|
||||
for example in list_data_dict
|
||||
]
|
||||
targets = [
|
||||
example['output'] + tokenizer.eos_token
|
||||
for example in list_data_dict
|
||||
]
|
||||
targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]
|
||||
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
if isinstance(tokenizer, ChatGLMTokenizer):
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess_chatglm(sources, targets, tokenizer, max_length)
|
||||
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
|
||||
sources, targets, tokenizer, max_length
|
||||
)
|
||||
else:
|
||||
self.input_ids, self.labels, self.attention_mask = \
|
||||
_preprocess(sources, targets, tokenizer, max_length)
|
||||
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
|
||||
|
||||
def __len__(self):
|
||||
length = self.input_ids.shape[0]
|
||||
@ -193,9 +181,6 @@ class SupervisedDataset(Dataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.attention_mask is not None:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx],
|
||||
attention_mask=self.attention_mask[idx])
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
|
||||
else:
|
||||
return dict(input_ids=self.input_ids[idx],
|
||||
labels=self.labels[idx])
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
|
@@ -1,4 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer

__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
@ -7,9 +7,9 @@ from coati.experience_maker.base import Experience
|
||||
class ExperienceBuffer(ABC):
|
||||
"""Experience buffer base class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
||||
|
@ -11,23 +11,23 @@ from .utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
class NaiveExperienceBuffer(ExperienceBuffer):
|
||||
"""Naive experience buffer class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
||||
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
||||
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
|
||||
super().__init__(sample_batch_size, limit)
|
||||
self.cpu_offload = cpu_offload
|
||||
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
|
||||
self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
# TODO(ver217): add prefetch
|
||||
self.items: List[BufferItem] = []
|
||||
|
||||
@torch.no_grad()
|
||||
def append(self, experience: Experience) -> None:
|
||||
if self.cpu_offload:
|
||||
experience.to_device(torch.device('cpu'))
|
||||
experience.to_device(torch.device("cpu"))
|
||||
items = split_experience_batch(experience)
|
||||
self.items.extend(items)
|
||||
if self.limit > 0:
|
||||
|
@ -21,6 +21,7 @@ class BufferItem:
|
||||
|
||||
"A" is the number of actions.
|
||||
"""
|
||||
|
||||
sequences: torch.Tensor
|
||||
action_log_probs: torch.Tensor
|
||||
values: torch.Tensor
|
||||
@ -33,8 +34,7 @@ class BufferItem:
|
||||
def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
||||
batch_size = experience.sequences.size(0)
|
||||
batch_kwargs = [{} for _ in range(batch_size)]
|
||||
keys = ('sequences', 'action_log_probs', 'values',
|
||||
'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
|
||||
for key in keys:
|
||||
value = getattr(experience, key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
||||
return items
|
||||
|
||||
|
||||
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
||||
assert side in ('left', 'right')
|
||||
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
|
||||
assert side in ("left", "right")
|
||||
max_len = max(seq.size(0) for seq in sequences)
|
||||
padded_sequences = []
|
||||
for seq in sequences:
|
||||
pad_len = max_len - seq.size(0)
|
||||
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
|
||||
padding = (pad_len, 0) if side == "left" else (0, pad_len)
|
||||
padded_sequences.append(F.pad(seq, padding))
|
||||
return torch.stack(padded_sequences, dim=0)
|
||||
|
||||
|
||||
def make_experience_batch(items: List[BufferItem]) -> Experience:
|
||||
kwargs = {}
|
||||
to_pad_keys = set(('action_log_probs', 'action_mask'))
|
||||
keys = ('sequences', 'action_log_probs', 'values',
|
||||
'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
to_pad_keys = set(("action_log_probs", "action_mask"))
|
||||
keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
|
||||
for key in keys:
|
||||
vals = [getattr(item, key) for item in items]
|
||||
if key in to_pad_keys:
|
||||
|
@@ -1,4 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker

__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]
@ -24,6 +24,7 @@ class Experience:
|
||||
|
||||
"A" is the number of actions.
|
||||
"""
|
||||
|
||||
sequences: torch.Tensor
|
||||
action_log_probs: torch.Tensor
|
||||
values: torch.Tensor
|
||||
@ -58,13 +59,9 @@ class Experience:
|
||||
|
||||
|
||||
class ExperienceMaker(ABC):
|
||||
|
||||
def __init__(self,
|
||||
actor: Actor,
|
||||
critic: nn.Module,
|
||||
reward_model: nn.Module,
|
||||
initial_model: Actor,
|
||||
kl_coef: float = 0.1) -> None:
|
||||
def __init__(
|
||||
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
|
@ -23,22 +23,21 @@ class NaiveExperienceMaker(ExperienceMaker):
|
||||
|
||||
# calculate auxiliary tensors
|
||||
attention_mask = None
|
||||
pad_token_id = generate_kwargs.get('pad_token_id', None)
|
||||
pad_token_id = generate_kwargs.get("pad_token_id", None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id)\
|
||||
.to(dtype=torch.long, device=sequences.device)
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = generate_kwargs.get('eos_token_id', None)
|
||||
eos_token_id = generate_kwargs.get("eos_token_id", None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
|
||||
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
|
||||
num_actions = action_mask.size(1)
|
||||
|
||||
actor_output = self.actor(sequences, attention_mask)
|
||||
|
@@ -1,6 +1,6 @@
from .wrapper import convert_to_xformer_model, recover_from_xformer_model

__all__ = [
'convert_to_xformer_model',
'recover_from_xformer_model',
"convert_to_xformer_model",
"recover_from_xformer_model",
]
@ -21,11 +21,12 @@ class XOPTAttention(OPTAttention):
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
|
||||
if not self.training:
|
||||
return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
|
||||
output_attentions)
|
||||
return super().forward(
|
||||
hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
|
||||
)
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
|
||||
assert not output_attentions, 'Xformers attention does not support output_attentions'
|
||||
assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
|
||||
assert not output_attentions, "Xformers attention does not support output_attentions"
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
@ -69,12 +70,14 @@ class XOPTAttention(OPTAttention):
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
attn_output = xops.memory_efficient_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xops.LowerTriangularMask(),
|
||||
p=self.dropout if self.training else 0.0,
|
||||
scale=self.scaling)
|
||||
attn_output = xops.memory_efficient_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_bias=xops.LowerTriangularMask(),
|
||||
p=self.dropout if self.training else 0.0,
|
||||
scale=self.scaling,
|
||||
)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned across GPUs when using tensor-parallelism.
|
||||
|
@@ -3,6 +3,13 @@ from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss

__all__ = [
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
"Actor",
"Critic",
"RewardModel",
"PolicyLoss",
"ValueLoss",
"LogSigLoss",
"LogExpLoss",
"LoRAModule",
"convert_to_lora_module",
]
@@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
Returns:
nn.Module: the base model
"""
assert isinstance(model, (Actor, Critic, RewardModel)), \
f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
assert isinstance(
model, (Actor, Critic, RewardModel)
), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
return model.model


__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]
@ -16,18 +16,17 @@ class Actor(LoRAModule):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
||||
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.convert_to_lora()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**model_kwargs, # HACK: `generate` method may pass more kwargs
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**model_kwargs, # HACK: `generate` method may pass more kwargs
|
||||
) -> torch.Tensor:
|
||||
"""Returns model output.
|
||||
"""
|
||||
"""Returns model output."""
|
||||
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
||||
return output
|
||||
|
@ -23,22 +23,23 @@ class Critic(LoRAModule):
|
||||
model: nn.Module,
|
||||
value_head: nn.Module,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
lora_train_bias: str = "none",
|
||||
use_action_mask: bool = False,
|
||||
) -> None:
|
||||
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.value_head = value_head
|
||||
self.use_action_mask = use_action_mask
|
||||
self.convert_to_lora()
|
||||
|
||||
def forward(self,
|
||||
sequences: torch.LongTensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
sequences: torch.LongTensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
last_hidden_states = outputs["last_hidden_state"]
|
||||
|
||||
values = self.value_head(last_hidden_states).squeeze(-1)
|
||||
|
||||
|
@ -17,11 +17,13 @@ class RewardModel(LoRAModule):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: nn.Module,
|
||||
value_head: Optional[nn.Module] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
value_head: Optional[nn.Module] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.convert_to_lora()
|
||||
@ -35,7 +37,7 @@ class RewardModel(LoRAModule):
|
||||
|
||||
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
last_hidden_states = outputs["last_hidden_state"]
|
||||
values = self.value_head(last_hidden_states)[:, :-1]
|
||||
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
|
||||
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
|
||||
return value
|
||||
|
@@ -2,4 +2,4 @@ from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM

__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"]
@ -1,7 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
from transformers import BloomConfig, BloomForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
@ -18,12 +17,14 @@ class BLOOMActor(Actor):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -1,8 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
from transformers import BloomConfig, BloomModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
@ -18,12 +17,14 @@ class BLOOMCritic(Critic):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
from transformers import BloomConfig, BloomModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
@ -17,11 +17,13 @@ class BLOOMRM(RewardModel):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@@ -1,3 +1,3 @@
from .chatglm_actor import ChatGLMActor

__all__ = ['ChatGLMActor']
__all__ = ["ChatGLMActor"]
@ -1,11 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from ..base import Actor
|
||||
from .configuration_chatglm import ChatGLMConfig
|
||||
from .modeling_chatglm import ChatGLMForConditionalGeneration
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class ChatGLMActor(Actor):
|
||||
"""
|
||||
@ -19,10 +17,9 @@ class ChatGLMActor(Actor):
|
||||
do not support lora for now.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[ChatGLMConfig] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
def __init__(
|
||||
self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
@ -31,4 +28,4 @@ class ChatGLMActor(Actor):
|
||||
model = ChatGLMForConditionalGeneration(ChatGLMConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank=0, lora_train_bias='none')
|
||||
super().__init__(model, lora_rank=0, lora_train_bias="none")
|
||||
|
@ -2,15 +2,14 @@
|
||||
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
|
||||
"""
|
||||
"""Tokenization classes for ChatGLM."""
|
||||
from typing import List, Optional, Union
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers.utils import logging, PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from typing import Dict
|
||||
import sentencepiece as spm
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
|
||||
from transformers.utils import PaddingStrategy, logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@ -52,11 +51,11 @@ class TextTokenizer:
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
num_image_tokens=20000,
|
||||
max_blank_length=80,
|
||||
byte_fallback=True,
|
||||
self,
|
||||
vocab_file,
|
||||
num_image_tokens=20000,
|
||||
max_blank_length=80,
|
||||
byte_fallback=True,
|
||||
):
|
||||
assert vocab_file is not None
|
||||
self.vocab_file = vocab_file
|
||||
@ -100,9 +99,7 @@ class SPTokenizer:
|
||||
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
|
||||
return text
|
||||
|
||||
def encode(
|
||||
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
|
||||
) -> List[int]:
|
||||
def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
|
||||
"""
|
||||
@param text: Text to encode.
|
||||
@param linebreak: Whether to encode newline (\n) in text.
|
||||
@ -136,9 +133,7 @@ class SPTokenizer:
|
||||
text = self.postprocess(text)
|
||||
return text
|
||||
|
||||
def tokenize(
|
||||
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
|
||||
) -> List[str]:
|
||||
def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
|
||||
"""
|
||||
@param text: Text to encode.
|
||||
@param linebreak: Whether to encode newline (\n) in text.
|
||||
@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=False,
|
||||
remove_space=False,
|
||||
bos_token='<sop>',
|
||||
eos_token='<eop>',
|
||||
end_token='</s>',
|
||||
mask_token='[MASK]',
|
||||
gmask_token='[gMASK]',
|
||||
padding_side="left",
|
||||
pad_token="<pad>",
|
||||
unk_token="<unk>",
|
||||
num_image_tokens=20000,
|
||||
**kwargs
|
||||
self,
|
||||
vocab_file,
|
||||
do_lower_case=False,
|
||||
remove_space=False,
|
||||
bos_token="<sop>",
|
||||
eos_token="<eop>",
|
||||
end_token="</s>",
|
||||
mask_token="[MASK]",
|
||||
gmask_token="[gMASK]",
|
||||
padding_side="left",
|
||||
pad_token="<pad>",
|
||||
unk_token="<unk>",
|
||||
num_image_tokens=20000,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
do_lower_case=do_lower_case,
|
||||
@ -208,7 +203,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
pad_token=pad_token,
|
||||
unk_token=unk_token,
|
||||
num_image_tokens=num_image_tokens,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
@ -243,11 +238,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Returns vocab size """
|
||||
"""Returns vocab size"""
|
||||
return self.sp_tokenizer.num_tokens
|
||||
|
||||
def get_vocab(self):
|
||||
""" Returns vocab as a dict """
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
@ -264,7 +259,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
return outputs
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
""" Returns a tokenized string. """
|
||||
"""Returns a tokenized string."""
|
||||
text = self.preprocess_text(text)
|
||||
|
||||
seq = self.sp_tokenizer.tokenize(text)
|
||||
@ -274,11 +269,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return self.sp_tokenizer.decode_tokens(tokens)
|
||||
|
||||
def _decode(
|
||||
self,
|
||||
token_ids: Union[int, List[int]],
|
||||
**kwargs
|
||||
) -> str:
|
||||
def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
if len(token_ids) == 0:
|
||||
@ -288,7 +279,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
return super()._decode(token_ids, **kwargs)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.sp_tokenizer[token]
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
@ -309,13 +300,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, 'rb') as fin:
|
||||
with open(self.vocab_file, "rb") as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
@ -324,7 +313,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
return (vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
@ -343,19 +332,19 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
gmask_id = self.sp_tokenizer[self.gmask_token]
|
||||
eos_id = self.sp_tokenizer[self.eos_token]
|
||||
self.sp_tokenizer[self.eos_token]
|
||||
token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 = token_ids_0 + token_ids_1
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
@ -421,17 +410,23 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
mask_position = required_input.index(mask_token)
|
||||
position_ids[context_length:] = mask_position
|
||||
block_position_ids = np.concatenate(
|
||||
[np.zeros(context_length, dtype=np.int64),
|
||||
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
|
||||
[
|
||||
np.zeros(context_length, dtype=np.int64),
|
||||
np.arange(1, seq_length - context_length + 1, dtype=np.int64),
|
||||
]
|
||||
)
|
||||
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
|
||||
pad_width=[(0, 0), (difference, 0), (difference, 0)],
|
||||
mode='constant', constant_values=True)
|
||||
encoded_inputs["attention_mask"] = np.pad(
|
||||
encoded_inputs["attention_mask"],
|
||||
pad_width=[(0, 0), (difference, 0), (difference, 0)],
|
||||
mode="constant",
|
||||
constant_values=True,
|
||||
)
|
||||
if "token_type_ids" in encoded_inputs:
|
||||
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
|
||||
"token_type_ids"
|
||||
@ -439,8 +434,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
if "special_tokens_mask" in encoded_inputs:
|
||||
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
|
||||
pad_width=[(0, 0), (difference, 0)])
|
||||
encoded_inputs["position_ids"] = np.pad(
|
||||
encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
|
||||
)
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
```"""
|
||||
model_type = "chatglm"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=130528,
|
||||
hidden_size=4096,
|
||||
num_layers=28,
|
||||
num_attention_heads=32,
|
||||
layernorm_epsilon=1e-5,
|
||||
use_cache=True,
|
||||
bos_token_id=130004,
|
||||
eos_token_id=130005,
|
||||
mask_token_id=130000,
|
||||
gmask_token_id=130001,
|
||||
pad_token_id=3,
|
||||
max_sequence_length=2048,
|
||||
inner_hidden_size=16384,
|
||||
position_encoding_2d=True,
|
||||
quantization_bit=0,
|
||||
pre_seq_len=None,
|
||||
prefix_projection=False,
|
||||
**kwargs
|
||||
self,
|
||||
vocab_size=130528,
|
||||
hidden_size=4096,
|
||||
num_layers=28,
|
||||
num_attention_heads=32,
|
||||
layernorm_epsilon=1e-5,
|
||||
use_cache=True,
|
||||
bos_token_id=130004,
|
||||
eos_token_id=130005,
|
||||
mask_token_id=130000,
|
||||
gmask_token_id=130001,
|
||||
pad_token_id=3,
|
||||
max_sequence_length=2048,
|
||||
inner_hidden_size=16384,
|
||||
position_encoding_2d=True,
|
||||
quantization_bit=0,
|
||||
pre_seq_len=None,
|
||||
prefix_projection=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.vocab_size = vocab_size
|
||||
@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig):
|
||||
self.pre_seq_len = pre_seq_len
|
||||
self.prefix_projection = prefix_projection
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
**kwargs
|
||||
)
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
@ -4,41 +4,40 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/mo
|
||||
|
||||
""" PyTorch ChatGLM model. """
|
||||
|
||||
import math
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||
from torch.nn.utils import skip_init
|
||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
||||
|
||||
from .configuration_chatglm import ChatGLMConfig
|
||||
|
||||
# flags required to enable jit fusion kernels
|
||||
|
||||
if sys.platform != 'darwin':
|
||||
if sys.platform != "darwin":
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_override_can_fuse_on_cpu(True)
|
||||
@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(
|
||||
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||
for n in name
|
||||
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
||||
for n in name
|
||||
):
|
||||
logger.info(f"Skipping {'/'.join(name)}")
|
||||
continue
|
||||
@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
assert (
|
||||
pointer.shape == array.shape
|
||||
pointer.shape == array.shape
|
||||
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
@ -153,7 +152,7 @@ class PrefixEncoder(torch.nn.Module):
|
||||
self.trans = torch.nn.Sequential(
|
||||
torch.nn.Linear(config.hidden_size, config.hidden_size),
|
||||
torch.nn.Tanh(),
|
||||
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
|
||||
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
|
||||
)
|
||||
else:
|
||||
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
|
||||
@ -170,8 +169,7 @@ class PrefixEncoder(torch.nn.Module):
|
||||
@torch.jit.script
|
||||
def gelu_impl(x):
|
||||
"""OpenAI's gelu implementation."""
|
||||
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
|
||||
(1.0 + 0.044715 * x * x)))
|
||||
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
|
||||
|
||||
|
||||
def gelu(x):
|
||||
@ -181,21 +179,22 @@ def gelu(x):
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
|
||||
super().__init__()
|
||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
inv_freq = inv_freq.half()
|
||||
self.learnable = learnable
|
||||
if learnable:
|
||||
self.inv_freq = torch.nn.Parameter(inv_freq)
|
||||
self.max_seq_len_cached = None
|
||||
else:
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.max_seq_len_cached = None
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
self.precision = precision
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(self, x, seq_dim=1, seq_len=None):
|
||||
@ -204,7 +203,7 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
|
||||
self.max_seq_len_cached = None if self.learnable else seq_len
|
||||
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
||||
if self.precision == torch.bfloat16:
|
||||
@ -230,30 +229,31 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
|
||||
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
|
||||
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
|
||||
F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
|
||||
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
|
||||
position_id, sin.squeeze(1)
|
||||
).unsqueeze(2)
|
||||
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
||||
return q, k
|
||||
|
||||
|
||||
def attention_fn(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
hidden_size_per_partition,
|
||||
layer_id,
|
||||
layer_past=None,
|
||||
scaling_attention_score=True,
|
||||
use_cache=False,
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
hidden_size_per_partition,
|
||||
layer_id,
|
||||
layer_past=None,
|
||||
scaling_attention_score=True,
|
||||
use_cache=False,
|
||||
):
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
@ -285,7 +285,9 @@ def attention_fn(
|
||||
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
||||
|
||||
matmul_result = torch.zeros(
|
||||
1, 1, 1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
dtype=query_layer.dtype,
|
||||
device=query_layer.device,
|
||||
)
|
||||
@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):
|
||||
|
||||
|
||||
class SelfAttention(torch.nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads,
|
||||
layer_id, hidden_size_per_attention_head=None, bias=True,
|
||||
params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
layer_id,
|
||||
hidden_size_per_attention_head=None,
|
||||
bias=True,
|
||||
params_dtype=torch.float,
|
||||
position_encoding_2d=True,
|
||||
empty_init=True,
|
||||
):
|
||||
if empty_init:
|
||||
init_method = skip_init
|
||||
else:
|
||||
@ -410,8 +420,7 @@ class SelfAttention(torch.nn.Module):
|
||||
attention_scores.masked_fill_(attention_mask, -10000.0)
|
||||
return attention_scores
|
||||
|
||||
def split_tensor_along_last_dim(self, tensor, num_partitions,
|
||||
contiguous_split_chunks=False):
|
||||
def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
|
||||
"""Split a tensor along its last dimension.
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
@ -431,14 +440,14 @@ class SelfAttention(torch.nn.Module):
|
||||
return tensor_list
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_id,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_id,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
"""
|
||||
hidden_states: [seq_len, batch, hidden_size]
|
||||
@ -462,8 +471,10 @@ class SelfAttention(torch.nn.Module):
|
||||
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
|
||||
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
|
||||
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
|
||||
position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
|
||||
position_ids[:, 1, :].transpose(0, 1).contiguous()
|
||||
position_ids, block_position_ids = (
|
||||
position_ids[:, 0, :].transpose(0, 1).contiguous(),
|
||||
position_ids[:, 1, :].transpose(0, 1).contiguous(),
|
||||
)
|
||||
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
|
||||
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
|
||||
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
|
||||
@ -484,7 +495,7 @@ class SelfAttention(torch.nn.Module):
|
||||
hidden_size_per_partition=self.hidden_size_per_partition,
|
||||
layer_id=layer_id,
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
output = self.dense(context_layer)
|
||||
@ -509,8 +520,16 @@ class GEGLU(torch.nn.Module):
|
||||
|
||||
|
||||
class GLU(torch.nn.Module):
|
||||
def __init__(self, hidden_size, inner_hidden_size=None,
|
||||
layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
inner_hidden_size=None,
|
||||
layer_id=None,
|
||||
bias=True,
|
||||
activation_func=gelu,
|
||||
params_dtype=torch.float,
|
||||
empty_init=True,
|
||||
):
|
||||
super(GLU, self).__init__()
|
||||
if empty_init:
|
||||
init_method = skip_init
|
||||
@ -557,19 +576,19 @@ class GLU(torch.nn.Module):
|
||||
|
||||
class GLMBlock(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
layernorm_epsilon,
|
||||
layer_id,
|
||||
inner_hidden_size=None,
|
||||
hidden_size_per_attention_head=None,
|
||||
layernorm=LayerNorm,
|
||||
use_bias=True,
|
||||
params_dtype=torch.float,
|
||||
num_layers=28,
|
||||
position_encoding_2d=True,
|
||||
empty_init=True
|
||||
self,
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
layernorm_epsilon,
|
||||
layer_id,
|
||||
inner_hidden_size=None,
|
||||
hidden_size_per_attention_head=None,
|
||||
layernorm=LayerNorm,
|
||||
use_bias=True,
|
||||
params_dtype=torch.float,
|
||||
num_layers=28,
|
||||
position_encoding_2d=True,
|
||||
empty_init=True,
|
||||
):
|
||||
super(GLMBlock, self).__init__()
|
||||
# Set output layer initialization if not provided.
|
||||
@ -590,7 +609,7 @@ class GLMBlock(torch.nn.Module):
|
||||
bias=use_bias,
|
||||
params_dtype=params_dtype,
|
||||
position_encoding_2d=self.position_encoding_2d,
|
||||
empty_init=empty_init
|
||||
empty_init=empty_init,
|
||||
)
|
||||
|
||||
# Layernorm on the input data.
|
||||
@ -605,18 +624,18 @@ class GLMBlock(torch.nn.Module):
|
||||
bias=use_bias,
|
||||
layer_id=layer_id,
|
||||
params_dtype=params_dtype,
|
||||
empty_init=empty_init
|
||||
empty_init=empty_init,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_id,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_id,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
"""
|
||||
hidden_states: [seq_len, batch, hidden_size]
|
||||
@ -635,7 +654,7 @@ class GLMBlock(torch.nn.Module):
|
||||
layer_id=layer_id,
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
attention_output = attention_outputs[0]
|
||||
@ -702,10 +721,15 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i]
block_position_ids = [torch.cat((
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
)) for context_length in context_lengths]
block_position_ids = [
torch.cat(
(
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
)
)
for context_length in context_lengths
]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else:
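To make the 2D position encoding above concrete, a small hedged sketch with made-up lengths showing what the two stacked id rows look like for a single sequence.

import torch

seq_length, context_length, mask_position = 6, 4, 3
position_ids = torch.arange(seq_length, dtype=torch.long)
position_ids[context_length:] = mask_position
block_position_ids = torch.cat(
    (torch.zeros(context_length, dtype=torch.long), torch.arange(seq_length - context_length, dtype=torch.long) + 1)
)
# position_ids       -> tensor([0, 1, 2, 3, 3, 3])
# block_position_ids -> tensor([0, 0, 0, 0, 1, 2])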
@ -823,9 +847,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
self.prefix_projection = config.prefix_projection
|
||||
|
||||
self.word_embeddings = init_method(
|
||||
torch.nn.Embedding,
|
||||
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
|
||||
dtype=self.params_dtype
|
||||
torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@ -841,12 +863,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
use_bias=True,
|
||||
params_dtype=self.params_dtype,
|
||||
position_encoding_2d=self.position_encoding_2d,
|
||||
empty_init=empty_init
|
||||
empty_init=empty_init,
|
||||
)
|
||||
|
||||
self.layers = torch.nn.ModuleList(
|
||||
[get_layer(layer_id) for layer_id in range(self.num_layers)]
|
||||
)
|
||||
self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
|
||||
|
||||
# Final layer norm before output.
|
||||
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
|
||||
@ -876,7 +896,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
self.pre_seq_len,
|
||||
self.num_layers * 2,
|
||||
self.num_attention_heads,
|
||||
self.hidden_size // self.num_attention_heads
|
||||
self.hidden_size // self.num_attention_heads,
|
||||
)
|
||||
# seq_len, b, nh, hidden_size
|
||||
past_key_values = self.dropout(past_key_values)
|
||||
@ -891,18 +911,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@ -931,17 +950,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
|
||||
if past_key_values is None:
|
||||
if self.pre_seq_len is not None:
|
||||
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
|
||||
dtype=inputs_embeds.dtype)
|
||||
past_key_values = self.get_prompt(
|
||||
batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
|
||||
)
|
||||
else:
|
||||
past_key_values = tuple([None] * len(self.layers))
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = self.get_masks(
|
||||
input_ids,
|
||||
device=input_ids.device
|
||||
)
|
||||
|
||||
attention_mask = self.get_masks(input_ids, device=input_ids.device)
|
||||
|
||||
if position_ids is None:
|
||||
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
||||
@ -955,15 +971,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
use_gmasks.append(use_gmask)
|
||||
|
||||
position_ids = self.get_position_ids(
|
||||
input_ids,
|
||||
mask_positions=mask_positions,
|
||||
device=input_ids.device,
|
||||
use_gmasks=use_gmasks
|
||||
input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
|
||||
)
|
||||
|
||||
if self.pre_seq_len is not None and attention_mask is not None:
|
||||
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
|
||||
attention_mask.device)
|
||||
attention_mask.device
|
||||
)
|
||||
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
|
||||
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
|
||||
|
||||
@ -980,7 +994,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
layer_past = past_key_values[i]
|
||||
@ -994,7 +1007,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
torch.tensor(i),
|
||||
layer_past,
|
||||
use_cache,
|
||||
output_attentions
|
||||
output_attentions,
|
||||
)
|
||||
else:
|
||||
layer_ret = layer(
|
||||
@ -1004,7 +1017,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
layer_id=torch.tensor(i),
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = layer_ret[0]
|
||||
@ -1049,13 +1062,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
|
||||
self.transformer = ChatGLMModel(config, empty_init=empty_init)
|
||||
|
||||
self.lm_head = init_method(
|
||||
nn.Linear,
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
dtype=torch.half
|
||||
)
|
||||
self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
|
||||
|
||||
self.config = config
|
||||
|
||||
@ -1087,32 +1094,29 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
if attention_mask is not None and attention_mask.dtype == torch.bool:
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
|
||||
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
|
||||
)
|
||||
new_attention_mask = attention_mask[:, :, -1:].clone()
|
||||
new_attention_mask[..., -1] = False
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, new_attention_mask], dim=2
|
||||
)
|
||||
model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
|
||||
|
||||
# update position ids
|
||||
if "position_ids" in model_kwargs:
|
||||
position_ids = model_kwargs["position_ids"]
|
||||
new_position_id = position_ids[..., -1:].clone()
|
||||
new_position_id[:, 1, :] += 1
|
||||
model_kwargs["position_ids"] = torch.cat(
|
||||
[position_ids, new_position_id], dim=-1
|
||||
)
|
||||
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
||||
@ -1137,11 +1141,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
|
||||
if self.position_encoding_2d:
|
||||
position_ids = torch.tensor(
|
||||
[[mask_position, seq_length - context_length] for mask_position, context_length in
|
||||
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
|
||||
[
|
||||
[mask_position, seq_length - context_length]
|
||||
for mask_position, context_length in zip(mask_positions, context_lengths)
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
).unsqueeze(-1)
|
||||
else:
|
||||
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
|
||||
device=input_ids.device).unsqueeze(-1)
|
||||
position_ids = torch.tensor(
|
||||
[mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
|
||||
).unsqueeze(-1)
|
||||
|
||||
if past is None:
|
||||
past = past_key_values
|
||||
@ -1149,44 +1159,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
"input_ids": last_token,
|
||||
"past_key_values": past,
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
else:
|
||||
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
||||
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
|
||||
attention_mask = None
|
||||
if attention_mask is None:
|
||||
attention_mask = self.get_masks(
|
||||
input_ids,
|
||||
device=input_ids.device
|
||||
)
|
||||
attention_mask = self.get_masks(input_ids, device=input_ids.device)
|
||||
if position_ids is None:
|
||||
position_ids = self.get_position_ids(
|
||||
input_ids,
|
||||
device=input_ids.device,
|
||||
mask_positions=mask_positions,
|
||||
use_gmasks=use_gmasks
|
||||
input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@ -1235,7 +1239,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(
|
||||
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
||||
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
||||
"""
|
||||
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
||||
@ -1268,15 +1272,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
return response
|
||||
|
||||
@torch.no_grad()
|
||||
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
||||
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
||||
def chat(
|
||||
self,
|
||||
tokenizer,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = None,
|
||||
max_length: int = 2048,
|
||||
num_beams=1,
|
||||
do_sample=True,
|
||||
top_p=0.7,
|
||||
temperature=0.95,
|
||||
logits_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
if history is None:
|
||||
history = []
|
||||
if logits_processor is None:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
||||
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
||||
gen_kwargs = {
|
||||
"max_length": max_length,
|
||||
"num_beams": num_beams,
|
||||
"do_sample": do_sample,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"logits_processor": logits_processor,
|
||||
**kwargs,
|
||||
}
|
||||
if not history:
|
||||
prompt = query
|
||||
else:
|
||||
@ -1287,22 +1309,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
outputs = self.generate(**inputs, **gen_kwargs)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
history = history + [(query, response)]
return response, history
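A hypothetical usage sketch for the chat helper refactored above; the checkpoint name and generation settings are assumptions for illustration, not taken from this commit.

from transformers import AutoModel, AutoTokenizer

# placeholder checkpoint; any ChatGLM-style model exposing .chat() would do
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
response, history = model.chat(tokenizer, "Hello", history=[], max_length=256, top_p=0.7, temperature=0.95)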
|
||||
|
||||
@torch.no_grad()
|
||||
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
|
||||
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
||||
def stream_chat(
|
||||
self,
|
||||
tokenizer,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = None,
|
||||
max_length: int = 2048,
|
||||
do_sample=True,
|
||||
top_p=0.7,
|
||||
temperature=0.95,
|
||||
logits_processor=None,
|
||||
**kwargs,
|
||||
):
|
||||
if history is None:
|
||||
history = []
|
||||
if logits_processor is None:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
||||
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
||||
gen_kwargs = {
|
||||
"max_length": max_length,
|
||||
"do_sample": do_sample,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"logits_processor": logits_processor,
|
||||
**kwargs,
|
||||
}
|
||||
if not history:
|
||||
prompt = query
|
||||
else:
|
||||
@ -1313,7 +1351,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
inputs = tokenizer([prompt], return_tensors="pt")
|
||||
inputs = inputs.to(self.device)
|
||||
for outputs in self.stream_generate(**inputs, **gen_kwargs):
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
|
||||
response = tokenizer.decode(outputs)
|
||||
response = self.process_response(response)
|
||||
new_history = history + [(query, response)]
|
||||
@ -1321,13 +1359,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
||||
|
||||
@torch.no_grad()
|
||||
def stream_generate(
|
||||
self,
|
||||
input_ids,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
**kwargs,
|
||||
self,
|
||||
input_ids,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
||||
|
||||
|
@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper


def _prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
def _prepare_logits_processor(
top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
|
||||
|
||||
def _sample(model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
def _sample(
|
||||
model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs,
|
||||
) -> torch.Tensor:
|
||||
if input_ids.size(1) >= max_length:
|
||||
return input_ids
|
||||
|
||||
@ -56,11 +58,12 @@ def _sample(model: Actor,
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(input_ids.size(1), max_length):
|
||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
|
||||
if prepare_inputs_fn is not None else {'input_ids': input_ids}
|
||||
model_inputs = (
|
||||
prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
|
||||
)
|
||||
outputs = model(**model_inputs)
|
||||
|
||||
next_token_logits = outputs['logits'][:, -1, :]
|
||||
next_token_logits = outputs["logits"][:, -1, :]
|
||||
# pre-process distribution
|
||||
next_token_logits = logits_processor(input_ids, next_token_logits)
|
||||
# sample
|
||||
@ -90,20 +93,22 @@ def _sample(model: Actor,
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
num_beams: int = 1,
|
||||
do_sample: bool = True,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
def generate(
|
||||
model: Actor,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
num_beams: int = 1,
|
||||
do_sample: bool = True,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
|
||||
|
||||
Args:
|
||||
@ -121,26 +126,28 @@ def generate(model: Actor,
|
||||
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
|
||||
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
|
||||
"""
|
||||
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
|
||||
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
|
||||
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
|
||||
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
|
||||
is_sample_gen_mode = (num_beams == 1) and do_sample is True
|
||||
is_beam_gen_mode = (num_beams > 1) and do_sample is False
|
||||
if is_greedy_gen_mode:
|
||||
# run greedy search
|
||||
raise NotImplementedError
|
||||
elif is_sample_gen_mode:
|
||||
# run sample
|
||||
return _sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
return _sample(
|
||||
model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif is_beam_gen_mode:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
|
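A hedged usage sketch for the generate helper above; `actor` and `tokenizer` are assumed to exist and are placeholders, and the sampling settings are illustrative.

# assumes `actor` is a coati Actor and `tokenizer` a Hugging Face tokenizer
input_ids = tokenizer("Hello", return_tensors="pt").input_ids
sequences = generate(
    actor,
    input_ids,
    max_length=64,
    num_beams=1,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.7,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)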
@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM

__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
__all__ = ["GPTActor", "GPTCritic", "GPTRM"]
|
@ -18,13 +18,15 @@ class GPTActor(Actor):
|
||||
lora_train_bias (str): Bias training strategy for the LoRa layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -18,12 +18,14 @@ class GPTCritic(Critic):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -18,11 +18,13 @@ class GPTRM(RewardModel):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
|
||||
from .llama_critic import LlamaCritic
|
||||
from .llama_rm import LlamaRM
|
||||
|
||||
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
|
||||
__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
@ -18,13 +17,14 @@ class LlamaActor(Actor):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -17,13 +17,14 @@ class LlamaCritic(Critic):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = LlamaModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
|
||||
from transformers import LlamaConfig, LlamaModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = LlamaModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -8,8 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
|
||||
"""
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
bias: Optional[nn.Parameter],
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
lora_dropout: float = 0.0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
merge_weights: bool = True,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
lora.LoRALayer.__init__(self,
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
lora.LoRALayer.__init__(
|
||||
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
|
||||
)
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
|
||||
@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, 'lora_A'):
|
||||
if hasattr(self, "lora_A"):
|
||||
# Initialize A with the default values for nn.Linear and set B to zero.
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
||||
delattr(self, 'lora_A')
|
||||
delattr(self, 'lora_B')
|
||||
delattr(self, "lora_A")
|
||||
delattr(self, "lora_B")
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
|
||||
|
||||
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
|
||||
assert (
|
||||
lora_rank <= linear.in_features
|
||||
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
||||
return lora_linear
|
||||
|
||||
@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
_convert_to_lora_recursively(child, lora_rank)
|
||||
|
||||
|
||||
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
|
||||
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
|
||||
"""Convert a torch.nn.Module to a LoRA module.
|
||||
|
||||
Args:
|
||||
@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
|
||||
Defaults to 'none'.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
||||
def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
|
||||
super().__init__()
|
||||
self.lora_rank = lora_rank
|
||||
self.lora_train_bias = lora_train_bias
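A brief, hedged example of wrapping a plain module with the convert_to_lora_module helper shown above; the toy MLP and rank are illustrative assumptions.

import torch.nn as nn

mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
# every nn.Linear inside is replaced by a LoraLinear carrying rank-8 adapters
lora_mlp = convert_to_lora_module(mlp, lora_rank=8, lora_train_bias="none")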
|
||||
|
@ -31,11 +31,13 @@ class PolicyLoss(nn.Module):
|
||||
super().__init__()
|
||||
self.clip_eps = clip_eps
|
||||
|
||||
def forward(self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
ratio = (log_probs - old_log_probs).exp()
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
@ -55,14 +57,16 @@ class ValueLoss(nn.Module):
|
||||
super().__init__()
|
||||
self.clip_eps = clip_eps
|
||||
|
||||
def forward(self,
|
||||
values: torch.Tensor,
|
||||
old_values: torch.Tensor,
|
||||
reward: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
values: torch.Tensor,
|
||||
old_values: torch.Tensor,
|
||||
reward: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2
surr2 = (values - reward)**2
surr1 = (values_clipped - reward) ** 2
surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
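For intuition, a hedged numeric sketch of the clipped surrogate terms that PolicyLoss and ValueLoss compute above; the reductions shown here (negated min for the policy, mean of the max for the value) are the standard PPO forms and are an assumption where the diff hunk cuts off.

import torch

clip_eps = 0.2
ratio = torch.tensor([0.5, 1.0, 1.5])
advantages = torch.tensor([1.0, 1.0, -1.0])
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()

values, old_values, reward = torch.tensor([0.3]), torch.tensor([0.5]), torch.tensor([1.0])
values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
value_loss = 0.5 * torch.max((values_clipped - reward) ** 2, (values - reward) ** 2).mean()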
|
@ -2,4 +2,4 @@ from .opt_actor import OPTActor
|
||||
from .opt_critic import OPTCritic
|
||||
from .opt_rm import OPTRM
|
||||
|
||||
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
|
||||
__all__ = ["OPTActor", "OPTCritic", "OPTRM"]
|
||||
|
@ -18,12 +18,14 @@ class OPTActor(Actor):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -18,12 +18,14 @@ class OPTCritic(Critic):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -17,11 +17,13 @@ class OPTRM(RewardModel):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = "none",
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@ -4,9 +4,9 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def _compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def _compute_approx_kl(
log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl

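A hedged sketch of the low-variance KL estimator described in the docstring above (the k3 form from the linked Schulman post); the exact sign convention used inside _compute_approx_kl is not visible in this hunk, so treat this as an assumption.

import torch

def approx_kl(log_probs: torch.Tensor, log_probs_base: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(log_ratio) - 1 - log_ratio, non-negative and unbiased in expectation
    log_ratio = log_probs_base - log_probs
    return log_ratio.exp() - 1 - log_ratio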
def compute_reward(r: Union[torch.Tensor, float],
|
||||
kl_coef: float,
|
||||
log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def compute_reward(
|
||||
r: Union[torch.Tensor, float],
|
||||
kl_coef: float,
|
||||
log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if kl_coef <= 0.0:
|
||||
return r
|
||||
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
|
||||
Returns:
|
||||
torch.Tensor: Action log probs.
|
||||
"""
|
||||
logits = output['logits']
|
||||
logits = output["logits"]
|
||||
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
|
@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
|
||||
from .utils import low_resource_init
|
||||
|
||||
__all__ = [
|
||||
'llama_load_quant',
|
||||
'low_resource_init',
|
||||
"llama_load_quant",
|
||||
"low_resource_init",
|
||||
]
|
||||
|
@ -1,5 +1,5 @@
from .loader import load_quant

__all__ = [
'load_quant',
"load_quant",
]
|
@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
|
||||
|
||||
# ignore lm head
|
||||
layers = find_layers(model)
|
||||
for name in ['lm_head']:
|
||||
for name in ["lm_head"]:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
|
||||
make_quant(model, layers, wbits, groupsize)
|
||||
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
if checkpoint.endswith(".safetensors"):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
|
||||
model.load_state_dict(safe_load(checkpoint))
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
|
@ -1,13 +1,12 @@
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py

import torch
import torch.nn as nn


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res
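A small usage example of the find_layers helper above on a toy module; the returned keys follow the dotted child names.

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
layers = find_layers(net)
# -> {'0': Linear(in_features=4, out_features=4, ...), '2': Linear(in_features=4, out_features=2, ...)}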
|
@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
|
||||
|
||||
|
||||
class Quantizer(nn.Module):
|
||||
|
||||
def __init__(self, shape=1):
|
||||
super(Quantizer, self).__init__()
|
||||
self.register_buffer('maxq', torch.tensor(0))
|
||||
self.register_buffer('scale', torch.zeros(shape))
|
||||
self.register_buffer('zero', torch.zeros(shape))
|
||||
self.register_buffer("maxq", torch.tensor(0))
|
||||
self.register_buffer("scale", torch.zeros(shape))
|
||||
self.register_buffer("zero", torch.zeros(shape))
|
||||
|
||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
|
||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
|
||||
self.maxq = torch.tensor(2**bits - 1)
|
||||
self.perchannel = perchannel
|
||||
self.sym = sym
|
||||
@ -68,7 +67,7 @@ class Quantizer(nn.Module):
|
||||
self.zero = torch.round(-xmin / self.scale)
|
||||
|
||||
if self.mse:
|
||||
best = torch.full([x.shape[0]], float('inf'), device=dev)
|
||||
best = torch.full([x.shape[0]], float("inf"), device=dev)
|
||||
for i in range(int(self.maxshrink * self.grid)):
|
||||
p = 1 - i / self.grid
|
||||
xmin1 = p * xmin
|
||||
@ -123,13 +122,12 @@ class Quantizer(nn.Module):
|
||||
try:
|
||||
import quant_cuda
|
||||
except:
|
||||
print('CUDA extension not installed.')
|
||||
print("CUDA extension not installed.")
|
||||
|
||||
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures):
|
||||
super().__init__()
|
||||
if bits not in [2, 3, 4, 8]:
|
||||
@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
|
||||
groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.groupsize = groupsize
|
||||
self.register_buffer(
|
||||
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
|
||||
dtype=torch.int))
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
||||
self.register_buffer('bias', torch.zeros(outfeatures))
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
||||
"qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
|
||||
)
|
||||
self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
||||
self.register_buffer("bias", torch.zeros(outfeatures))
|
||||
self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
||||
self._initialized_quant_state = False
|
||||
|
||||
def pack(self, linear, scales, zeros):
|
||||
@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
|
||||
for idx in range(self.infeatures):
|
||||
g_idx = idx // self.groupsize
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
|
||||
None])
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
|
||||
:, None
|
||||
]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
|
||||
return y.reshape(outshape)
|
||||
|
||||
|
||||
def make_quant(module, names, bits, groupsize, name=''):
|
||||
def make_quant(module, names, bits, groupsize, name=""):
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
for attr in dir(module):
|
||||
tmp = getattr(module, attr)
|
||||
name1 = name + '.' + attr if name != '' else attr
|
||||
name1 = name + "." + attr if name != "" else attr
|
||||
if name1 in names:
|
||||
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
|
||||
for name1, child in module.named_children():
|
||||
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
|
||||
make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)
|
||||
|
@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
|
||||
|
||||
@contextmanager
|
||||
def low_resource_init():
|
||||
"""This context manager disables weight initialization and sets the default float dtype to half.
|
||||
"""
|
||||
"""This context manager disables weight initialization and sets the default float dtype to half."""
|
||||
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
|
||||
old_uniform_ = torch.nn.init.uniform_
|
||||
old_normal_ = torch.nn.init.normal_
|
||||
|
@ -5,7 +5,7 @@ from coati.experience_maker import Experience
|
||||
|
||||
class TrainerCallback(ABC):
|
||||
"""
|
||||
Base callback class. It defines the interface for callbacks.
|
||||
Base callback class. It defines the interface for callbacks.
|
||||
"""
|
||||
|
||||
def on_fit_start(self) -> None:
|
||||
@ -40,7 +40,6 @@ class TrainerCallback(ABC):
|
||||
|
||||
|
||||
class MakerCallback(ABC):
|
||||
|
||||
def on_loop_start(self) -> None:
|
||||
pass
|
||||
|
||||
|
@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
|
||||
|
||||
class Timer:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.start_time: Optional[float] = None
|
||||
self.duration: float = 0.
|
||||
self.duration: float = 0.0
|
||||
|
||||
def start(self) -> None:
|
||||
self.start_time = time()
|
||||
@ -42,13 +41,13 @@ class Timer:
|
||||
self.duration += time() - self.start_time
|
||||
|
||||
def reset(self) -> None:
|
||||
self.duration = 0.
|
||||
self.duration = 0.0
|
||||
|
||||
|
||||
class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||
|
||||
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
|
||||
reward_model_num_params: int) -> None:
|
||||
def __init__(
|
||||
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||
self.make_experience_flop: int = 0
|
||||
|
||||
print_rank_0(
|
||||
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
|
||||
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
|
||||
)
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
|
||||
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
|
||||
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
||||
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
|
||||
(self.total_samples * self.world_size)
|
||||
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
|
||||
self.total_samples * self.world_size
|
||||
)
|
||||
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
||||
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
|
||||
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
||||
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
"Making Experience Performance Summary:\n"
|
||||
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
|
||||
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
|
||||
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
|
||||
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
)
|
||||
|
||||
|
||||
class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
|
||||
def __init__(self,
|
||||
actor_num_params: int,
|
||||
critic_num_params: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_first_episodes: int = 1) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
actor_num_params: int,
|
||||
critic_num_params: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_first_episodes: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
self.learn_flop: int = 0
|
||||
|
||||
print_rank_0(
|
||||
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
|
||||
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
|
||||
)
|
||||
|
||||
def on_episode_start(self, episodes: int) -> None:
|
||||
@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
if self.total_samples == 0:
|
||||
print_rank_0('No samples are collected, skip trainer performance evaluation')
|
||||
print_rank_0("No samples are collected, skip trainer performance evaluation")
|
||||
return
|
||||
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
|
||||
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
|
||||
@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
|
||||
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
|
||||
|
||||
print_rank_0(
|
||||
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
|
||||
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
|
||||
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
|
||||
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
|
||||
"Learning Performance Summary:\n"
|
||||
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
|
||||
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
|
||||
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
|
||||
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
|
||||
)
|
||||
|
@ -1,20 +1,15 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import random
|
||||
from threading import Lock
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from coati.experience_buffer import ExperienceBuffer
|
||||
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
# from torch.multiprocessing import Queue
|
||||
from ray.util.queue import Queue
|
||||
|
||||
|
||||
class DetachedReplayBuffer:
'''
"""
Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
'''
"""
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
||||
self.sample_batch_size = sample_batch_size
|
||||
@ -34,23 +29,23 @@ class DetachedReplayBuffer:
|
||||
|
||||
@torch.no_grad()
|
||||
def append(self, experience: Experience) -> None:
|
||||
'''
|
||||
"""
|
||||
Expected to be called remotely.
|
||||
'''
|
||||
"""
|
||||
items = split_experience_batch(experience)
|
||||
self.extend(items)
|
||||
|
||||
@torch.no_grad()
|
||||
def extend(self, items: List[BufferItem]) -> None:
|
||||
'''
|
||||
"""
|
||||
Expected to be called remotely.
|
||||
'''
|
||||
"""
|
||||
self.batch_collector.extend(items)
|
||||
while len(self.batch_collector) >= self.sample_batch_size:
|
||||
items = self.batch_collector[:self.sample_batch_size]
|
||||
items = self.batch_collector[: self.sample_batch_size]
|
||||
experience = make_experience_batch(items)
|
||||
self.items.put(experience, block=True)
|
||||
self.batch_collector = self.batch_collector[self.sample_batch_size:]
|
||||
self.batch_collector = self.batch_collector[self.sample_batch_size :]
|
||||
|
||||
def clear(self) -> None:
|
||||
# self.items.close()
|
||||
|
@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List

import ray
import torch
@ -15,7 +15,7 @@ from .utils import is_rank_0


class DetachedTrainer(ABC):
'''
"""
Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init:
@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating

'''
"""

def __init__(self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False) -> None:
def __init__(
self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False,
) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
dataloader = DataLoader(data,
batch_size=1,
shuffle=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
dataloader = DataLoader(
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
)
for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
@ -104,7 +104,7 @@ class DetachedTrainer(ABC):

def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()
@ -1,12 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple

import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam

from colossalai.nn.optimizer import HybridAdam
@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to


@ray.remote(
concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
)

@ray.remote(concurrency_groups={
"buffer_length": 1,
"buffer_append": 1,
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
'''
"""
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training
@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""

def __init__(
self,
@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)

(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
(self.actor, self.actor_optim), (self.critic, self.critic_optim)
)

# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)

super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug)
super().__init__(
experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug,
)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")

self._update_lora_weights = update_lora_weights

@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
config['requires_grad_only'] = True
config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update))
fully_update=fully_update,
)
)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update))
fully_update=fully_update,
)
)
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):

num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()

values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)

self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}

def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)
@ -1,53 +1,49 @@
import os
import time
import tracemalloc
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import ray
import torch
import torch.nn as nn
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm

from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
"""
Args:
detached_trainer_name_list: str list to get ray actor handles
strategy:
kl_coef: the coefficient of kl divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
'''
"""

def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs):
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs,
):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
@ -66,8 +62,9 @@ class ExperienceMakerHolder:
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
reward_model_numel)
evaluator = ExperienceMakerPerformanceEvaluator(
actor_numel, critic_numel, initial_model_numel, reward_model_numel
)
callbacks = callbacks + [evaluator]

actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
@ -89,9 +86,9 @@ class ExperienceMakerHolder:
self._target_idx = 0

if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')
print(f"[maker{get_rank()}] Waiting for INIT")

def _get_ready(self):
while not self._fully_initialized():
@ -136,7 +133,7 @@ class ExperienceMakerHolder:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
experience.to_device('cpu')
experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
@ -155,7 +152,7 @@ class ExperienceMakerHolder:
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
@ -163,7 +160,7 @@ class ExperienceMakerHolder:
batch = next(it)
self._inference_step(batch)
else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
@ -171,22 +168,24 @@ class ExperienceMakerHolder:
self._on_loop_end()

@ray.method(concurrency_group="model_io")
def update_experience_maker(self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None):
'''
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
def update_experience_maker(
self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None,
):
"""
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing

TODO: load_state_dict integrate with model-sharding strategy
'''
TODO: load_state_dict integrate with model-sharding strategy
"""
_watch_memory = self._debug
if chunk_start:
if self._debug:
@ -202,18 +201,22 @@ class ExperienceMakerHolder:
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict)
new_actor_state_dict, new_actor_lora_config_dict
)
self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase)
self.experience_maker.actor.model, state_dict_increase
)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict)
new_critic_state_dict, new_critic_lora_config_dict
)
self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase)
self.experience_maker.critic, state_dict_increase
)

# the lock must be released after both actor and critic being updated
if chunk_end:
@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation

if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation

return new_kwargs
@ -1,11 +1,9 @@
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict

import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer


@dataclass
@ -17,7 +15,7 @@ class LoRAConfig:


class LoRAConstructor:
'''
"""
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:
@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()

'''
"""

def __init__(self):
self.lora_config_dict = None
@ -45,10 +43,10 @@ class LoRAConstructor:
self.lora_config_dict = lora_config_dict

def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
'''
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
'''
"""
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
"""
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)

@ -56,24 +54,25 @@ class LoRAConstructor:
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A':
if k.rpartition(".")[-1] == "lora_A":
lora_A = v
layer_prefix = k.rpartition('.')[0]
elif k.rpartition('.')[-1] == 'lora_B':
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
layer_prefix = k.rpartition(".")[0]
elif k.rpartition(".")[-1] == "lora_B":
assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
raise ValueError('unexpected key')
raise ValueError("unexpected key")
return state_dict_increase

def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w

if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
@ -81,21 +80,21 @@ class LoRAConstructor:
return 0

def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
'''
"""
The final reconstruction step
'''
"""
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)

@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
'''
"""
if keep_non_lora, also return non_lora state_dict
'''
"""
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k:
if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
@ -106,17 +105,19 @@ class LoRAConstructor:

@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
'''
"""
extract LoraLinear model.
return OrderedDict(): name -> LoRAConfig
'''
"""
lora_config_dict = OrderedDict()

for name, child in model.named_modules():
if isinstance(child, LoraLinear):
lora_config_dict[name] = LoRAConfig(r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out)
lora_config_dict[name] = LoRAConfig(
r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out,
)

return lora_config_dict
@ -1,6 +1,6 @@
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict

import torch
import torch.distributed as dist
@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer


def is_rank_0() -> bool:
@ -26,13 +26,13 @@ def get_world_size() -> int:


def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom':
elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt':
elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama':
elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra


def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom':
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt':
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama':
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r


def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2':
if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom':
elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt':
elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama':
elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):


def get_strategy_from_args(strategy: str):
if strategy == 'ddp':
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_


def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
elif model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif model == 'opt':
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama':
elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):


def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
os.environ["RANK"] = env_info["rank"]
os.environ["LOCAL_RANK"] = env_info["local_rank"]
os.environ["WORLD_SIZE"] = env_info["world_size"]
os.environ["MASTER_PORT"] = env_info["master_port"]
os.environ["MASTER_ADDR"] = env_info["master_addr"]


def get_model_numel(model: nn.Module) -> int:
@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers


def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
keep state_dict intact
'''
def state_dict_to(
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
"""
keep state_dict intact
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)
@ -3,8 +3,4 @@ from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer

__all__ = [
'SLTrainer', 'OnPolicyTrainer',
'RewardModelTrainer', 'SFTTrainer',
'PPOTrainer'
]
__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]

@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""

def __init__(self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []) -> None:
def __init__(
self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = [],
) -> None:
super().__init__()
self.strategy = strategy
self.data_buffer = data_buffer

@ -2,4 +2,4 @@ from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint

__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
@ -5,7 +5,7 @@ from coati.experience_maker import Experience

class Callback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""

def on_fit_start(self) -> None:

@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:

def divide(x: float, y: float) -> float:
if y == 0:
return float('inf')
elif y == float('inf'):
return float('nan')
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y


@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:


class Timer:

def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0

def start(self) -> None:
self.start_time = time()
@ -52,7 +51,7 @@ class Timer:
self.start_time = None

def reset(self) -> None:
self.duration = 0.
self.duration = 0.0


class PerformanceEvaluator(Callback):
@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""

def __init__(self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)

avg_make_experience_throughput = self.make_experience_num_samples * \
self.world_size / (avg_make_experience_duration + 1e-12)
avg_make_experience_throughput = (
self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)

avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)

print_rank_0(
f'Performance summary:\n'
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
f"Performance summary:\n"
+ f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):

"""

def __init__(self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None) -> None:
def __init__(
self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None,
) -> None:
super().__init__()
self.path = os.path.join(path, 'checkpoint')
self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}

def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
base_path = os.path.join(self.path, f'episode_{episode}')
base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)

for model in self.model_dict.keys():

# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped
continue
model_path = os.path.join(base_path, f'{model}.pt')
model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)

# save optimizer
@ -71,5 +72,5 @@ class SaveCheckpoint(Callback):
continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
@ -8,7 +8,7 @@ from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DistributedSampler
from tqdm import tqdm

from colossalai.utils import get_current_device
@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation

if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation

return new_kwargs

@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating
"""

def __init__(self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs
) -> None:
def __init__(
self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"

data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)

self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
@ -130,18 +126,16 @@ class PPOTrainer(OnPolicyTrainer):
num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)

# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)

self.strategy.backward(actor_loss, self.actor, self.actor_optim)
@ -149,24 +143,23 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim.zero_grad()

# value loss
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()

return {'reward': experience.reward.mean().item()}
return {"reward": experience.reward.mean().item()}

def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")

# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
@ -178,11 +171,7 @@ class PPOTrainer(OnPolicyTrainer):
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(
self.dataloader,
desc=f'Train epoch [{update_step + 1}]',
disable=not is_rank_0()
)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)
@ -62,18 +62,15 @@ class RewardModelTrainer(SLTrainer):

if is_rank_0():
log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader),
self.loss.item(), self.dist, self.acc]],
columns=['step', 'loss', 'dist', 'acc']
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
columns=["step", "loss", "dist", "acc"],
)
log.to_csv('log.csv', mode='a', header=False, index=False)
log.to_csv("log.csv", mode="a", header=False, index=False)

def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0()
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
@ -93,10 +90,7 @@ class RewardModelTrainer(SLTrainer):
step_bar.update()
step_bar.close()

def _before_fit(self,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader):
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
@ -104,7 +98,7 @@ class RewardModelTrainer(SLTrainer):
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader
@ -39,8 +39,9 @@ class SFTTrainer(SLTrainer):
accumulation_steps: int = 8,
) -> None:
if accumulation_steps > 1:
assert not isinstance(strategy, GeminiStrategy), \
"Accumulation steps are not supported in stage 3 of ColossalAI"
assert not isinstance(
strategy, GeminiStrategy
), "Accumulation steps are not supported in stage 3 of ColossalAI"

super().__init__(strategy, max_epochs, model, optim)

@ -50,15 +51,11 @@ class SFTTrainer(SLTrainer):
def _train(self, epoch: int):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):

batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch:
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
else:
outputs = self.model(batch["input_ids"],
labels=batch["labels"])
outputs = self.model(batch["input_ids"], labels=batch["labels"])

loss = outputs.loss
loss = loss / self.accumulation_steps
@ -73,12 +70,14 @@ class SFTTrainer(SLTrainer):
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.use_wandb:
wandb.log({
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
wandb.log(
{
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id,
}
)
self.total_loss = 0
self.step_bar.update()

@ -89,9 +88,9 @@ class SFTTrainer(SLTrainer):
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
)
loss = outputs.loss

loss_sum += loss.item()
@ -99,13 +98,15 @@ class SFTTrainer(SLTrainer):

loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")

def _before_fit(self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
use_wandb: bool = False):
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader: the dataloader to use for training
@ -124,6 +125,6 @@ class SFTTrainer(SLTrainer):
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc=f'steps',
disable=not is_rank_0()
desc=f"steps",
disable=not is_rank_0(),
)
@ -2,7 +2,4 @@ from .base import Strategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy

__all__ = [
'Strategy', 'DDPStrategy',
'LowLevelZeroStrategy', 'GeminiStrategy'
]
__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]

@ -19,7 +19,7 @@ _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]

class Strategy(ABC):
"""
Base class for training strategies.
Base class for training strategies.
"""

def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
@ -83,16 +83,18 @@ class Strategy(ABC):
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
boost_result = dict(
model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler,
)
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f'Type {type(arg)} is not supported')
raise RuntimeError(f"Type {type(arg)} is not supported")

return rets[0] if len(rets) == 1 else rets

@ -125,11 +127,9 @@ class Strategy(ABC):
return DistributedSampler(dataset, 1, 0)

@abstractmethod
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
pass

@abstractmethod

@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy):

"""

def __init__(self,
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
overlap_communication: bool = True,  # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:

def __init__(
self,
stage: int = 2,
precision: str = "fp16",
seed: int = 42,
placement_policy: str = "cuda",
reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
overlap_communication: bool = True,  # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'

plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
@ -71,7 +71,7 @@ class LowLevelZeroStrategy(DDPStrategy):
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
cpu_offload=(placement_policy == "cpu"),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
@ -81,14 +81,15 @@ class LowLevelZeroStrategy(DDPStrategy):
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
norm_type=norm_type,
)

super().__init__(seed, plugin_initializer)

def _post_init(self) -> None:
assert isinstance(self.plugin, LowLevelZeroPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
assert isinstance(
self.plugin, LowLevelZeroPlugin
), f"{type(self).__name__}'s plugin is not initialized properly."

def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy):

"""

def __init__(self,
seed: int = 42,
shard_init: bool = False,  # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True,  # only for stage 3
force_outputs_fp32: bool = False,  # only for stage 3
search_range_m: int = 32,  # only for stage 3
hidden_dim: Optional[int] = None,  # only for stage 3
min_chunk_size_m: float = 32,  # only for stage 3
gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:

assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
def __init__(
self,
seed: int = 42,
shard_init: bool = False,  # only for stage 3
placement_policy: str = "cuda",
pin_memory: bool = True,  # only for stage 3
force_outputs_fp32: bool = False,  # only for stage 3
search_range_m: int = 32,  # only for stage 3
hidden_dim: Optional[int] = None,  # only for stage 3
min_chunk_size_m: float = 32,  # only for stage 3
gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'

# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
f"Shard init is not supported model.from_pretrained() yet. "
"Please load weights after strategy.prepare()"
)
self.shard_init = shard_init

warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")

# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
@ -187,14 +188,13 @@ class GeminiStrategy(DDPStrategy):
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
norm_type=norm_type,
)

super().__init__(seed, plugin_initializer)
|
||||
|
||||
def _post_init(self) -> None:
|
||||
assert isinstance(self.plugin, GeminiPlugin), \
|
||||
f'{type(self).__name__}\'s plugin is not initialized properly.'
|
||||
assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
colossalai.launch_from_torch({}, seed=self.seed)
|
||||
@ -203,10 +203,9 @@ class GeminiStrategy(DDPStrategy):
|
||||
world_size = dist.get_world_size()
|
||||
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
|
||||
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
|
||||
return ColoInitContext(device=get_current_device(),
|
||||
dtype=torch.half,
|
||||
default_pg=shard_pg,
|
||||
default_dist_spec=default_dist_spec)
|
||||
return ColoInitContext(
|
||||
device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
|
||||
)
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
assert isinstance(model, GeminiModel)
|
||||
|
@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module):
|
||||
|
||||
class DDPStrategy(Strategy):
|
||||
"""
|
||||
Strategy for distributed training using torch.distributed.
|
||||
Strategy for distributed training using torch.distributed.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
seed: int = 42,
|
||||
plugin_initializer: Callable = TorchDDPPlugin
|
||||
) -> None:
|
||||
def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
|
||||
self.seed = seed
|
||||
super().__init__(plugin_initializer)
|
||||
|
||||
def _try_init_dist(self, force: bool = False) -> None:
|
||||
try:
|
||||
rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
world_size = int(os.environ['WORLD_SIZE'])
|
||||
host = os.environ['MASTER_ADDR']
|
||||
port = int(os.environ['MASTER_PORT'])
|
||||
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
|
||||
rank = int(os.environ["RANK"])
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
host = os.environ["MASTER_ADDR"]
|
||||
port = int(os.environ["MASTER_PORT"])
|
||||
dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
|
||||
torch.cuda.set_device(local_rank)
|
||||
except KeyError as e:
|
||||
if force:
|
||||
@ -60,8 +57,7 @@ class DDPStrategy(Strategy):
|
||||
raise e
|
||||
|
||||
def _post_init(self) -> None:
|
||||
assert isinstance(self.plugin, TorchDDPPlugin), \
|
||||
f'{type(self).__name__}\'s plugin is not initialized properly.'
|
||||
assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
self._try_init_dist(force=True)
|
||||
@ -73,12 +69,14 @@ class DDPStrategy(Strategy):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
return self.plugin.prepare_dataloader(data_buffer,
|
||||
batch_size=data_buffer.sample_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=data_buffer.collate_fn)
|
||||
return self.plugin.prepare_dataloader(
|
||||
data_buffer,
|
||||
batch_size=data_buffer.sample_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=data_buffer.collate_fn,
|
||||
)
|
||||
|
||||
def setup_sampler(self, dataset) -> DistributedSampler:
|
||||
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
|
||||
@ -88,11 +86,9 @@ class DDPStrategy(Strategy):
|
||||
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
|
||||
return model.unwrap()
|
||||
|
||||
def save_pretrained(self,
|
||||
model: nn.Module,
|
||||
path: str,
|
||||
only_rank0: bool = True,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
def save_pretrained(
|
||||
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
|
||||
) -> None:
|
||||
if not only_rank0 or dist.get_rank() == 0:
|
||||
unwrapped_model = self.unwrap_model(model)
|
||||
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
|
||||
@ -103,17 +99,11 @@ class DDPStrategy(Strategy):
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model,
|
||||
model_path,
|
||||
only_rank0=only_rank0)
|
||||
self.save_model(model, model_path, only_rank0=only_rank0)
|
||||
|
||||
def _replace_keys(model_path: str,
|
||||
replace_fn: Callable):
|
||||
def _replace_keys(model_path: str, replace_fn: Callable):
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
state_dict = {
|
||||
replace_fn(k): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
|
||||
@ -124,13 +114,13 @@ class DDPStrategy(Strategy):
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
# TODO: implement sharding on naive strategy
|
||||
model = self.unwrap_model(model)
|
||||
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
|
||||
if "requires_grad_only" in config and config["requires_grad_only"] == True:
|
||||
state_dict = get_grad_required_state_dict(model)
|
||||
else:
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if 'shard_size' in config:
|
||||
shard_size = config['shard_size']
|
||||
if "shard_size" in config:
|
||||
shard_size = config["shard_size"]
|
||||
accumulate_size = 0
|
||||
state_dict_shard = OrderedDict()
|
||||
for name, param in state_dict.items():
|
||||
|
@ -4,7 +4,6 @@ import numpy as np


class DistributedSampler:

def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
@ -12,7 +11,7 @@ class DistributedSampler:

if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
@ -20,10 +19,10 @@ class DistributedSampler:
self.total_size = self.num_samples * self.num_replicas

indices = list(range(len(self.dataset)))
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
@ -42,7 +42,6 @@ def is_rank_0() -> bool:


def to_device(x: Any, device: torch.device) -> Any:

def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
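The helper in this hunk moves tensors inside a batch to a target device. A short independent sketch of the same pattern, assuming only standard PyTorch (this is not the commit's exact code, which is truncated in the hunk above):

```python
from typing import Any

import torch


def to_device(x: Any, device: torch.device) -> Any:
    """Recursively move tensors inside nested dicts/lists/tuples to `device` (illustrative sketch)."""

    def _to(t: Any):
        if isinstance(t, torch.Tensor):
            return t.to(device)
        if isinstance(t, dict):
            return {k: _to(v) for k, v in t.items()}
        if isinstance(t, (list, tuple)):
            return type(t)(_to(v) for v in t)
        return t

    return _to(x)


batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "meta": "kept as-is"}
batch = to_device(batch, torch.device("cpu"))
```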
@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import openai
|
||||
@ -9,7 +8,8 @@ from utils import jload
|
||||
|
||||
def main(args):
|
||||
assert len(args.answer_file_list) == len(
|
||||
args.model_name_list), "The number of answer files and model names should be equal!"
|
||||
args.model_name_list
|
||||
), "The number of answer files and model names should be equal!"
|
||||
|
||||
# load config
|
||||
config = jload(args.config_file)
|
||||
@ -36,7 +36,8 @@ def main(args):
|
||||
|
||||
if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
|
||||
raise Exception(
|
||||
"No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!")
|
||||
"No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
|
||||
)
|
||||
|
||||
if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
|
||||
raise Exception(
|
||||
@ -44,8 +45,15 @@ def main(args):
|
||||
)
|
||||
|
||||
# initialize evaluator
|
||||
evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
|
||||
config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference)
|
||||
evaluator = Evaluator(
|
||||
metrics_per_category,
|
||||
battle_prompt,
|
||||
gpt_evaluation_prompt,
|
||||
args.gpt_model,
|
||||
config["language"],
|
||||
config.get("path_for_UniEval", None),
|
||||
args.gpt_with_reference,
|
||||
)
|
||||
if len(args.model_name_list) == 2:
|
||||
answers1 = jload(args.answer_file_list[0])
|
||||
answers2 = jload(args.answer_file_list[1])
|
||||
@ -68,41 +76,41 @@ def main(args):
|
||||
raise ValueError(f'Unsupported language {config["language"]}!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.')
|
||||
parser.add_argument('--config_file',
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help='path to the file of target results')
|
||||
parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle')
|
||||
parser.add_argument('--gpt_evaluation_prompt_file',
|
||||
type=str,
|
||||
default=None,
|
||||
help='path to the prompt file for gpt evaluation')
|
||||
parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file')
|
||||
parser.add_argument('--answer_file_list',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=[],
|
||||
required=True,
|
||||
help='path to the answer files of at most 2 models')
|
||||
parser.add_argument('--model_name_list',
|
||||
type=str,
|
||||
nargs='+',
|
||||
default=[],
|
||||
required=True,
|
||||
help='the names of at most 2 models')
|
||||
parser.add_argument('--gpt_model',
|
||||
default="gpt-3.5-turbo",
|
||||
choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
|
||||
help='which GPT model to use for evaluation')
|
||||
parser.add_argument('--gpt_with_reference',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='whether to include reference answer in gpt evaluation')
|
||||
parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results')
|
||||
parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key')
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
|
||||
parser.add_argument(
|
||||
"--config_file", type=str, default=None, required=True, help="path to the file of target results"
|
||||
)
|
||||
parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
|
||||
parser.add_argument(
|
||||
"--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
|
||||
)
|
||||
parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
|
||||
parser.add_argument(
|
||||
"--answer_file_list",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[],
|
||||
required=True,
|
||||
help="path to the answer files of at most 2 models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpt_model",
|
||||
default="gpt-3.5-turbo",
|
||||
choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
|
||||
help="which GPT model to use for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpt_with_reference",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="whether to include reference answer in gpt evaluation",
|
||||
)
|
||||
parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
|
||||
parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.openai_key is not None:
|
||||
|
@ -3,20 +3,27 @@ from typing import Any, Dict, List
|
||||
|
||||
import gpt_evaluate
|
||||
import metrics
|
||||
import pandas as pd
|
||||
import unieval
|
||||
from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
|
||||
|
||||
|
||||
class Evaluator(object):
|
||||
"""
|
||||
A class named Evaluator includes GPT-3.5/GPT-4 evaluation
|
||||
and automatic evaluation
|
||||
A class named Evaluator includes GPT-3.5/GPT-4 evaluation
|
||||
and automatic evaluation
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
|
||||
gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
params: Dict[str, Any],
|
||||
battle_prompt: Dict[str, Any],
|
||||
gpt_evaluation_prompt: Dict[str, Any],
|
||||
gpt_model: str,
|
||||
language: str,
|
||||
path_for_UniEval: Dict[str, str],
|
||||
gpt_with_reference: bool,
|
||||
) -> None:
|
||||
self.params = params
|
||||
self.battle_prompt = battle_prompt
|
||||
self.gpt_evaluation_prompt = gpt_evaluation_prompt
|
||||
@ -103,7 +110,8 @@ class Evaluator(object):
|
||||
|
||||
if self.params[category]["UniEval"] and self.language == "cn":
|
||||
raise Exception(
|
||||
"UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.")
|
||||
"UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file."
|
||||
)
|
||||
|
||||
category_metrics = self.params[category]["UniEval"]
|
||||
|
||||
@ -134,10 +142,9 @@ class Evaluator(object):
|
||||
sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
|
||||
|
||||
data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
|
||||
scores = uni_evaluator.evaluate(data,
|
||||
category,
|
||||
dims=list(self.unieval_metric_stats[task][category].keys()),
|
||||
overall=False)
|
||||
scores = uni_evaluator.evaluate(
|
||||
data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False
|
||||
)
|
||||
avg_scores = unieval.calculate_average_score(scores)
|
||||
|
||||
self.unieval_metric_stats[task][category].update(avg_scores)
|
||||
@ -165,7 +172,8 @@ class Evaluator(object):
|
||||
category,
|
||||
self.gpt_model,
|
||||
self.language,
|
||||
references=targets_per_category[category] if self.gpt_with_reference else None)
|
||||
references=targets_per_category[category] if self.gpt_with_reference else None,
|
||||
)
|
||||
|
||||
def save(self, path: str, model_name_list: List[str]) -> None:
|
||||
"""
|
||||
@ -204,16 +212,18 @@ class Evaluator(object):
|
||||
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
|
||||
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
|
||||
|
||||
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
|
||||
self.gpt_evaluation_results,
|
||||
gpt_evaluation_results_save_path)
|
||||
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
|
||||
model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
|
||||
)
|
||||
|
||||
# Start to calculate scores and save statistics.
|
||||
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
|
||||
gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
|
||||
gpt_evaluation_statistics_save_path)
|
||||
gpt_evaluate.save_gpt_evaluation_statistics(
|
||||
model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
|
||||
)
|
||||
|
||||
# Save charts and csv.
|
||||
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
|
||||
gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
|
||||
gpt_evaluation_analyses_save_path)
|
||||
gpt_evaluate.analyze_gpt_evaluation_statistics(
|
||||
gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
|
||||
)
|
||||
|
@ -14,20 +14,18 @@ import tqdm
from utils import jdump, jload

ref_step_template = {
"en":
"Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
"cn":
"请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n"
"en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
"cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}。\n\n",
}

ref_answer_template_general = {
"en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
"cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n"
"cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
}

ref_answer_template_correctness = {
"en": "\nA correct answer is as follows:\n\n{answer}\n\n",
"cn": "\n标准答案如下:\n\n{answer}\n\n"
"cn": "\n标准答案如下:\n\n{answer}\n\n",
}
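These templates are later filled via `str.format` (see the `reference_template` hunk further down, which calls `step_to_add.format(metric=..., adjective=...)`). A minimal sketch of that substitution, with placeholder values chosen purely for illustration:

```python
# Illustrative only: fill the English reference-step template shown above.
ref_step_en = (
    "Now please compare the answer with the {adjective} answer, "
    "determine whether the answer is able to achieve the same level of {metric}.\n\n"
)
prompt_piece = ref_step_en.format(adjective="example", metric="relevance")
print(prompt_piece)
```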
@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": sys_prompt
|
||||
},
|
||||
{"role": "system", "content": sys_prompt},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]:
|
||||
return [float(sp[0]), float(sp[1])]
|
||||
else:
|
||||
raise Exception(f"Invalid score pair. Got {evaluation}.")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return [-1, -1]
|
||||
|
||||
|
||||
@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
|
||||
|
||||
assert len(answer1) == len(answer2)
|
||||
|
||||
handles = []
|
||||
evaluation_file = []
|
||||
|
||||
total_len = len(answer1)
|
||||
question_idx_list = list(range(total_len))
|
||||
|
||||
@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
|
||||
assert answer1[i]["id"] == answer2[i]["id"]
|
||||
answer_id = answer1[i]["id"]
|
||||
|
||||
ques = answer1[i]["instruction"] if answer1[i][
|
||||
"input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"]
|
||||
cat = answer1[i]["category"]
|
||||
ques = (
|
||||
answer1[i]["instruction"]
|
||||
if answer1[i]["input"] == ""
|
||||
else answer1[i]["instruction"] + " " + answer1[i]["input"]
|
||||
)
|
||||
answer1[i]["category"]
|
||||
ans1 = answer1[i]["output"]
|
||||
ans2 = answer2[i]["output"]
|
||||
|
||||
@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
|
||||
|
||||
step_to_add = ref_step_template[language]
|
||||
|
||||
for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)"
|
||||
for_the_given_answer = (
|
||||
"{metric} (1-5) (directly give the score for the given answer):"
|
||||
if language == "en"
|
||||
else "{metric} (1-5) (直接对给定答案打分)"
|
||||
)
|
||||
|
||||
# adjective is used to describe the word "answer" in the prompt.
|
||||
adjective = "example" if language == "en" else "示例"
|
||||
@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
|
||||
answer_to_add = ref_answer_template_correctness[language]
|
||||
|
||||
answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
|
||||
step_to_add = step_to_add.format(metric=metric.lower(),
|
||||
adjective=adjective) + for_the_given_answer.format(metric=metric)
|
||||
step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
|
||||
metric=metric
|
||||
)
|
||||
|
||||
return answer_to_add + step_to_add
|
||||
|
||||
@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
|
||||
for j in range(i):
|
||||
messages_to_send.append(fill_in_message("user", user_messages[j]))
|
||||
messages_to_send.append(
|
||||
fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]))
|
||||
fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
|
||||
)
|
||||
|
||||
# Length of user messages == Length of assistant messages + 1
|
||||
# Because we always expect the api to response
|
||||
@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
|
||||
return assistant_responses[-1]
|
||||
|
||||
|
||||
def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
|
||||
inst: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
language: str,
|
||||
reference: Dict[str, Any] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
max_tokens: int = 2048) -> Dict[str, Any]:
|
||||
def get_gpt_evaluation_without_logprobs(
|
||||
prompt: Dict[str, Any],
|
||||
inst: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
language: str,
|
||||
reference: Dict[str, Any] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
max_tokens: int = 2048,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer.
|
||||
|
||||
@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
|
||||
|
||||
MAX_API_RETRY = 3
|
||||
|
||||
question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
|
||||
question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
|
||||
answer = inst["output"]
|
||||
inst["evaluation"] = {}
|
||||
|
||||
@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
|
||||
|
||||
if prompt_reference:
|
||||
# Do a 2-round conversation
|
||||
response = multiturn_chat_completion([prompt_1st_round, prompt_reference],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
turns=2)
|
||||
response = multiturn_chat_completion(
|
||||
[prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
|
||||
)
|
||||
else:
|
||||
response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
|
||||
|
||||
@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
|
||||
return inst
|
||||
|
||||
|
||||
def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
|
||||
inst: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
max_tokens: int = 2048) -> Dict[str, Any]:
|
||||
def get_gpt_evaluation_with_logprobs(
|
||||
prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Use completion model(text-davinci-003) to evaluate one model answer.
|
||||
Only completion models can return log probabilities.
|
||||
@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
|
||||
|
||||
MAX_API_RETRY = 3
|
||||
|
||||
question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
|
||||
question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
|
||||
answer = inst["output"]
|
||||
inst["evaluation"] = {}
|
||||
|
||||
@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
|
||||
return inst
|
||||
|
||||
|
||||
def evaluate(answers: List[Dict],
|
||||
prompt: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
category: str,
|
||||
model: str,
|
||||
language: str,
|
||||
references: List[Dict] = None) -> List[Dict]:
|
||||
def evaluate(
|
||||
answers: List[Dict],
|
||||
prompt: Dict[str, Any],
|
||||
metrics: List[str],
|
||||
category: str,
|
||||
model: str,
|
||||
language: str,
|
||||
references: List[Dict] = None,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Use GPT models to evaluate model answers and save evaluation results.
|
||||
|
||||
@ -529,21 +532,23 @@ def evaluate(answers: List[Dict],
|
||||
if model == "text-davinci-003":
|
||||
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
|
||||
else:
|
||||
future = executor.submit(get_gpt_evaluation_without_logprobs,
|
||||
prompt,
|
||||
inst,
|
||||
metrics,
|
||||
language,
|
||||
reference=None if references is None else references[idx],
|
||||
model=model,
|
||||
max_tokens=1)
|
||||
future = executor.submit(
|
||||
get_gpt_evaluation_without_logprobs,
|
||||
prompt,
|
||||
inst,
|
||||
metrics,
|
||||
language,
|
||||
reference=None if references is None else references[idx],
|
||||
model=model,
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
futures.append(future)
|
||||
|
||||
for future in tqdm.tqdm(
|
||||
concurrent.futures.as_completed(futures),
|
||||
desc=f"{category}: ",
|
||||
total=len(futures),
|
||||
concurrent.futures.as_completed(futures),
|
||||
desc=f"{category}: ",
|
||||
total=len(futures),
|
||||
):
|
||||
evaluations.append(future.result())
|
||||
|
||||
@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) ->
|
||||
return int(results[0])
|
||||
else:
|
||||
raise Exception(f"Invalid score pair. Got {evaluation}.")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any],
|
||||
save_path: str) -> Dict[str, Any]:
|
||||
def save_gpt_evaluation_results(
|
||||
model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Save evaluation results for different categories for one model.
|
||||
|
||||
@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
|
||||
scores[metric].append(0)
|
||||
elif evaluation["evaluation"][metric]["logprobs"] is not None:
|
||||
scores[metric].append(
|
||||
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]))
|
||||
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
|
||||
)
|
||||
else:
|
||||
scores[metric].append(
|
||||
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation))
|
||||
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
|
||||
)
|
||||
|
||||
statistics = {}
|
||||
for metric in metrics:
|
||||
@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
|
||||
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
|
||||
|
||||
for category in tqdm.tqdm(
|
||||
frame_per_category.keys(),
|
||||
desc=f"GPT evaluation: ",
|
||||
total=len(frame_per_category.keys()),
|
||||
frame_per_category.keys(),
|
||||
desc=f"GPT evaluation: ",
|
||||
total=len(frame_per_category.keys()),
|
||||
):
|
||||
data = pd.DataFrame(frame_per_category[category])
|
||||
|
||||
|
@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
"""
bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
cumulative_bleu = [0] * 4
weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.),
(1. / 4., 1. / 4., 1. / 4., 1. / 4.)]
weights = [
(1.0 / 1.0, 0.0, 0.0, 0.0),
(1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0),
(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0),
(1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0),
]

for pred, target in zip(preds, targets):
if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()]
pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()]
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = [preprocessing_text(target).split()]
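The `weights` list in this hunk encodes cumulative 1- to 4-gram BLEU weightings. A hedged sketch of how such weight tuples are typically passed to NLTK's `sentence_bleu` (assuming NLTK is the library behind the surrounding metric code; the tokens below are toy data for illustration):

```python
from nltk.translate.bleu_score import sentence_bleu

weights = [
    (1.0, 0.0, 0.0, 0.0),        # BLEU-1
    (0.5, 0.5, 0.0, 0.0),        # cumulative BLEU-2
    (1 / 3, 1 / 3, 1 / 3, 0.0),  # cumulative BLEU-3
    (0.25, 0.25, 0.25, 0.25),    # cumulative BLEU-4
]

reference = [["the", "cat", "sat", "on", "the", "mat"]]
candidate = ["the", "cat", "is", "on", "the", "mat"]

for n, w in enumerate(weights, start=1):
    score = sentence_bleu(reference, candidate, weights=w)
    print(f"bleu{n}: {score:.4f}")
```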
@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,


def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
"""Calculate CHRF Score Metric in sentence level.
"""
"""Calculate CHRF Score Metric in sentence level."""
chrf_score = {"chrf": 0}
cumulative_chrf = []

for pred, target in zip(preds, targets):
if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
target_list = " ".join(jieba.cut(preprocessing_text(target))).split()
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = preprocessing_text(target).split()
@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
|
||||
all_targets = []
|
||||
|
||||
for pred, target in zip(preds, targets):
|
||||
pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred))))
|
||||
target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target))))
|
||||
pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred))))
|
||||
target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target))))
|
||||
all_preds.append(pred_list)
|
||||
all_targets.append(target_list)
|
||||
|
||||
@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
|
||||
longest common subsequence (LCS) between preds and targets.
|
||||
"""
|
||||
rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
|
||||
rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
|
||||
|
||||
for pred, target in zip(preds, targets):
|
||||
score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target))
|
||||
rouge_scores["rouge1"] += score['rouge1'].fmeasure
|
||||
rouge_scores["rouge2"] += score['rouge2'].fmeasure
|
||||
rouge_scores["rougeL"] += score['rougeL'].fmeasure
|
||||
rouge_scores["rouge1"] += score["rouge1"].fmeasure
|
||||
rouge_scores["rouge2"] += score["rouge2"].fmeasure
|
||||
rouge_scores["rougeL"] += score["rougeL"].fmeasure
|
||||
|
||||
rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds)
|
||||
rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds)
|
||||
@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
|
||||
|
||||
for pred in preds:
|
||||
if language == "cn":
|
||||
pred_seg_list = ' '.join(jieba.cut(pred)).split()
|
||||
pred_seg_list = " ".join(jieba.cut(pred)).split()
|
||||
count_segs = len(pred_seg_list)
|
||||
unique_segs = set(pred_seg_list)
|
||||
count_unique_chars = len(unique_segs)
|
||||
@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
|
||||
split_pred = preprocessing_text(pred).split()
|
||||
for n in range(0, 3):
|
||||
for i in range(0, len(split_pred) - n):
|
||||
ngram = ' '.join(split_pred[i:i + n + 1])
|
||||
ngram = " ".join(split_pred[i : i + n + 1])
|
||||
unique_ngram[n].add(ngram)
|
||||
all_ngram_count[n] += 1
|
||||
|
||||
@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language
|
||||
|
||||
for pred, target in zip(preds, targets):
|
||||
if language == "cn":
|
||||
pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()]
|
||||
target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()]
|
||||
pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()]
|
||||
target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()]
|
||||
elif language == "en":
|
||||
pred_list = [char for char in preprocessing_text(pred).split()]
|
||||
target_list = [char for char in preprocessing_text(target).split()]
|
||||
|
@ -7,6 +7,9 @@ from .utils import (
)

__all__ = [
'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
'analyze_unieval_results'
"get_evaluator",
"convert_data_to_unieval_format",
"calculate_average_score",
"save_unieval_results",
"analyze_unieval_results",
]
@ -28,29 +28,29 @@ from .utils import add_question
|
||||
|
||||
|
||||
class SumEvaluator:
|
||||
|
||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
||||
""" Set up evaluator for text summarization """
|
||||
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||
"""Set up evaluator for text summarization"""
|
||||
self.scorer = UniEvaluator(
|
||||
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
|
||||
model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
self.task = 'summarization'
|
||||
self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.task = "summarization"
|
||||
self.dimensions = ["coherence", "consistency", "fluency", "relevance"]
|
||||
|
||||
def evaluate(self, data, category, dims=None, overall=True):
|
||||
"""
|
||||
Get the scores of all the given dimensions
|
||||
Get the scores of all the given dimensions
|
||||
|
||||
category: The category to be evaluated.
|
||||
category: The category to be evaluated.
|
||||
|
||||
dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
|
||||
four dimensions: coherence, consistency, fluency, relevance.
|
||||
dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
|
||||
four dimensions: coherence, consistency, fluency, relevance.
|
||||
|
||||
overall: indicates whether the overall score is to be calculated.
|
||||
Overall score can be customized to a combination of scores based on different
|
||||
dimensions. The default here is the average score of all the given dimensions.
|
||||
overall: indicates whether the overall score is to be calculated.
|
||||
Overall score can be customized to a combination of scores based on different
|
||||
dimensions. The default here is the average score of all the given dimensions.
|
||||
"""
|
||||
n_data = len(data)
|
||||
eval_scores = [{} for _ in range(n_data)]
|
||||
@ -63,12 +63,12 @@ class SumEvaluator:
|
||||
|
||||
for dim in eval_dims:
|
||||
# Calculate average sentence-level scores for 'consistency' and 'fluency'
|
||||
if dim == 'consistency' or dim == 'fluency':
|
||||
if dim == "consistency" or dim == "fluency":
|
||||
src_list, output_list = [], []
|
||||
n_sents = [] # the number of sentences in each generated summary
|
||||
n_sents = [] # the number of sentences in each generated summary
|
||||
for i in range(n_data):
|
||||
source = data[i]['source']
|
||||
system_outputs = sent_tokenize(data[i]['system_output'])
|
||||
source = data[i]["source"]
|
||||
system_outputs = sent_tokenize(data[i]["system_output"])
|
||||
n_sents.append(len(system_outputs))
|
||||
for j in range(len(system_outputs)):
|
||||
src_list.append(source)
|
||||
@ -81,24 +81,26 @@ class SumEvaluator:
|
||||
score = []
|
||||
for cur_n_sent in n_sents:
|
||||
# prevent denominator from being 0
|
||||
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
|
||||
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
|
||||
start_idx += cur_n_sent
|
||||
|
||||
# Calculate summary-level score for 'coherence' and 'relevance'
|
||||
elif dim == 'coherence' or dim == 'relevance':
|
||||
elif dim == "coherence" or dim == "relevance":
|
||||
src_list, output_list, ref_list = [], [], []
|
||||
for i in range(n_data):
|
||||
src_list.append(data[i]['source'])
|
||||
output_list.append(data[i]['system_output'])
|
||||
if dim == 'relevance':
|
||||
ref_list.append(data[i]['reference'])
|
||||
src_list.append(data[i]["source"])
|
||||
output_list.append(data[i]["system_output"])
|
||||
if dim == "relevance":
|
||||
ref_list.append(data[i]["reference"])
|
||||
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
|
||||
score = self.scorer.score(input_list, self.task, category, dim)
|
||||
|
||||
# Please customize other dimensions here for summarization
|
||||
else:
|
||||
raise NotImplementedError('The input format for this dimension is still undefined. \
|
||||
Please customize it first.')
|
||||
raise NotImplementedError(
|
||||
"The input format for this dimension is still undefined. \
|
||||
Please customize it first."
|
||||
)
|
||||
|
||||
for i in range(n_data):
|
||||
eval_scores[i][dim] = score[i]
|
||||
@ -106,35 +108,35 @@ class SumEvaluator:
|
||||
# Customize your overall score here.
|
||||
if overall == True:
|
||||
for i in range(n_data):
|
||||
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
|
||||
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
|
||||
|
||||
return eval_scores
|
||||
|
||||
|
||||
class DialogEvaluator:
|
||||
|
||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
||||
""" Set up evaluator for dialogues """
|
||||
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||
"""Set up evaluator for dialogues"""
|
||||
self.scorer = UniEvaluator(
|
||||
model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
|
||||
model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
self.task = 'dialogue'
|
||||
self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.task = "dialogue"
|
||||
self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"]
|
||||
|
||||
def evaluate(self, data, category, dims=None, overall=True):
|
||||
"""
|
||||
Get the scores of all the given dimensions
|
||||
Get the scores of all the given dimensions
|
||||
|
||||
category: The category to be evaluated.
|
||||
category: The category to be evaluated.
|
||||
|
||||
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
|
||||
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
|
||||
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
|
||||
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
|
||||
|
||||
overall: indicates whether the overall score is to be calculated.
|
||||
Overall score can be customized to a combination of scores based on different
|
||||
dimensions. The default here is the average score of all the given dimensions.
|
||||
overall: indicates whether the overall score is to be calculated.
|
||||
Overall score can be customized to a combination of scores based on different
|
||||
dimensions. The default here is the average score of all the given dimensions.
|
||||
"""
|
||||
n_data = len(data)
|
||||
eval_scores = [{} for _ in range(n_data)]
|
||||
@ -147,50 +149,48 @@ class DialogEvaluator:
|
||||
|
||||
for dim in eval_dims:
|
||||
# Calculate summation score for 'engagingness'
|
||||
if dim == 'engagingness':
|
||||
if dim == "engagingness":
|
||||
src_list, output_list, context_list = [], [], []
|
||||
n_sents = [] # the number of sentences in each generated response
|
||||
n_sents = [] # the number of sentences in each generated response
|
||||
for i in range(n_data):
|
||||
source = data[i]['source']
|
||||
context = data[i]['context']
|
||||
system_outputs = sent_tokenize(data[i]['system_output'])
|
||||
source = data[i]["source"]
|
||||
context = data[i]["context"]
|
||||
system_outputs = sent_tokenize(data[i]["system_output"])
|
||||
n_sents.append(len(system_outputs))
|
||||
for j in range(len(system_outputs)):
|
||||
src_list.append(source)
|
||||
context_list.append(context)
|
||||
output_list.append(system_outputs[j])
|
||||
input_list = add_question(dimension=dim,
|
||||
output=output_list,
|
||||
src=src_list,
|
||||
context=context_list,
|
||||
task=self.task)
|
||||
input_list = add_question(
|
||||
dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
|
||||
)
|
||||
sent_score = self.scorer.score(input_list, self.task, category, dim)
|
||||
|
||||
# Get the summation score for each sample
|
||||
start_idx = 0
|
||||
score = []
|
||||
for cur_n_sent in n_sents:
|
||||
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
|
||||
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]))
|
||||
start_idx += cur_n_sent
|
||||
|
||||
# Calculate turn-level score for other dimensions
|
||||
elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
|
||||
elif dim in ["naturalness", "coherence", "groundedness", "understandability"]:
|
||||
src_list, output_list, context_list = [], [], []
|
||||
for i in range(n_data):
|
||||
src_list.append(data[i]['source'])
|
||||
output_list.append(data[i]['system_output'])
|
||||
context_list.append(data[i]['context'])
|
||||
input_list = add_question(dimension=dim,
|
||||
output=output_list,
|
||||
src=src_list,
|
||||
context=context_list,
|
||||
task=self.task)
|
||||
src_list.append(data[i]["source"])
|
||||
output_list.append(data[i]["system_output"])
|
||||
context_list.append(data[i]["context"])
|
||||
input_list = add_question(
|
||||
dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
|
||||
)
|
||||
score = self.scorer.score(input_list, self.task, category, dim)
|
||||
|
||||
# Please customize other dimensions here for summarization
|
||||
else:
|
||||
raise NotImplementedError('The input format for this dimension is still undefined. \
|
||||
Please customize it first.')
|
||||
raise NotImplementedError(
|
||||
"The input format for this dimension is still undefined. \
|
||||
Please customize it first."
|
||||
)
|
||||
|
||||
for i in range(n_data):
|
||||
eval_scores[i][dim] = score[i]
|
||||
@ -198,35 +198,35 @@ class DialogEvaluator:
|
||||
# Customize your overall score here.
|
||||
if overall == True:
|
||||
for i in range(n_data):
|
||||
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
|
||||
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
|
||||
|
||||
return eval_scores
|
||||
|
||||
|
||||
class D2tEvaluator:
|
||||
|
||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
||||
""" Set up evaluator for data-to-text """
|
||||
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||
"""Set up evaluator for data-to-text"""
|
||||
self.scorer = UniEvaluator(
|
||||
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
|
||||
model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
self.task = 'data2text'
|
||||
self.dimensions = ['naturalness', 'informativeness']
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.task = "data2text"
|
||||
self.dimensions = ["naturalness", "informativeness"]
|
||||
|
||||
def evaluate(self, data, category, dims=None, overall=True):
|
||||
"""
|
||||
Get the scores of all the given dimensions
|
||||
Get the scores of all the given dimensions
|
||||
|
||||
category: The category to be evaluated.
|
||||
category: The category to be evaluated.
|
||||
|
||||
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
|
||||
two dimensions: naturalness and informativeness.
|
||||
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
|
||||
two dimensions: naturalness and informativeness.
|
||||
|
||||
overall: indicates whether the overall score is to be calculated.
|
||||
Overall score can be customized to a combination of scores based on different
|
||||
dimensions. The default here is the average score of all the given dimensions.
|
||||
overall: indicates whether the overall score is to be calculated.
|
||||
Overall score can be customized to a combination of scores based on different
|
||||
dimensions. The default here is the average score of all the given dimensions.
|
||||
"""
|
||||
n_data = len(data)
|
||||
eval_scores = [{} for _ in range(n_data)]
|
||||
@ -240,8 +240,8 @@ class D2tEvaluator:
|
||||
for dim in eval_dims:
|
||||
output_list, ref_list = [], []
|
||||
for i in range(n_data):
|
||||
output_list.append(data[i]['system_output'])
|
||||
ref_list.append(data[i]['reference'])
|
||||
output_list.append(data[i]["system_output"])
|
||||
ref_list.append(data[i]["reference"])
|
||||
|
||||
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
|
||||
score = self.scorer.score(input_list, self.task, category, dim)
|
||||
@ -252,38 +252,38 @@ class D2tEvaluator:
|
||||
# Customize your overall score here.
|
||||
if overall == True:
|
||||
for i in range(n_data):
|
||||
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
|
||||
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
|
||||
|
||||
return eval_scores
|
||||
|
||||
|
||||
class FactEvaluator:
|
||||
|
||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
||||
""" Set up evaluator for factual consistency detection """
|
||||
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||
"""Set up evaluator for factual consistency detection"""
|
||||
self.scorer = UniEvaluator(
|
||||
model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
|
||||
model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
self.task = 'fact'
|
||||
self.dim = 'consistency'
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.task = "fact"
|
||||
self.dim = "consistency"
|
||||
|
||||
def evaluate(self, data, category):
|
||||
"""
|
||||
Get the factual consistency score (only 1 dimension for this task)
|
||||
Get the factual consistency score (only 1 dimension for this task)
|
||||
|
||||
category: The category to be evaluated.
|
||||
category: The category to be evaluated.
|
||||
"""
|
||||
n_data = len(data)
|
||||
eval_scores = [{} for _ in range(n_data)]
|
||||
|
||||
# Calculate average sentence-level scores for factual consistency
|
||||
src_list, output_list = [], []
|
||||
n_sents = [] # the number of sentences in the claim
|
||||
n_sents = [] # the number of sentences in the claim
|
||||
for i in range(n_data):
|
||||
source = data[i]['source']
|
||||
system_outputs = sent_tokenize(data[i]['system_output'])
|
||||
source = data[i]["source"]
|
||||
system_outputs = sent_tokenize(data[i]["system_output"])
|
||||
n_sents.append(len(system_outputs))
|
||||
for j in range(len(system_outputs)):
|
||||
src_list.append(source)
|
||||
@ -295,7 +295,7 @@ class FactEvaluator:
|
||||
start_idx = 0
|
||||
score = []
|
||||
for cur_n_sent in n_sents:
|
||||
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
|
||||
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent)
|
||||
start_idx += cur_n_sent
|
||||
|
||||
for i in range(n_data):
|
||||
@ -304,28 +304,26 @@ class FactEvaluator:
|
||||
return eval_scores
|
||||
|
||||
|
||||
def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
|
||||
assert task in ['summarization', 'dialogue', 'data2text', 'fact']
|
||||
if task == 'summarization':
|
||||
return SumEvaluator(model_name_or_path=model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
elif task == 'dialogue':
|
||||
return DialogEvaluator(model_name_or_path=model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
elif task == 'data2text':
|
||||
return D2tEvaluator(model_name_or_path=model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
elif task == 'fact':
|
||||
return FactEvaluator(model_name_or_path=model_name_or_path,
|
||||
max_length=max_length,
|
||||
device=device,
|
||||
cache_dir=cache_dir)
|
||||
def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None):
|
||||
assert task in ["summarization", "dialogue", "data2text", "fact"]
|
||||
if task == "summarization":
|
||||
return SumEvaluator(
|
||||
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||
)
|
||||
elif task == "dialogue":
|
||||
return DialogEvaluator(
|
||||
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||
)
|
||||
elif task == "data2text":
|
||||
return D2tEvaluator(
|
||||
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||
)
|
||||
elif task == "fact":
|
||||
return FactEvaluator(
|
||||
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError('Other tasks are not implemented, \
|
||||
please customize specific tasks here.')
|
||||
raise NotImplementedError(
|
||||
"Other tasks are not implemented, \
|
||||
please customize specific tasks here."
|
||||
)
|
||||
|
@ -27,9 +27,8 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
class UniEvaluator:
|
||||
|
||||
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
|
||||
""" Set up model """
|
||||
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
|
||||
"""Set up model"""
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
|
||||
@ -47,8 +46,8 @@ class UniEvaluator:
|
||||
|
||||
def score(self, inputs, task, category, dim, batch_size=8):
|
||||
"""
|
||||
Get scores for the given samples.
|
||||
final_score = postive_score / (postive_score + negative_score)
|
||||
Get scores for the given samples.
|
||||
final_score = postive_score / (postive_score + negative_score)
|
||||
"""
|
||||
|
||||
# The implementation of "forward" in T5 still requires decoder_input_ids.
|
||||
@ -58,31 +57,27 @@ class UniEvaluator:
|
||||
|
||||
pos_score_list, neg_score_list = [], []
|
||||
for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
|
||||
src_list = inputs[i:i + batch_size]
|
||||
tgt_list = tgts[i:i + batch_size]
|
||||
src_list = inputs[i : i + batch_size]
|
||||
tgt_list = tgts[i : i + batch_size]
|
||||
try:
|
||||
with torch.no_grad():
|
||||
encoded_src = self.tokenizer(src_list,
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors='pt')
|
||||
encoded_tgt = self.tokenizer(tgt_list,
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
padding=True,
|
||||
return_tensors='pt')
|
||||
encoded_src = self.tokenizer(
|
||||
src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
|
||||
)
|
||||
encoded_tgt = self.tokenizer(
|
||||
tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
src_tokens = encoded_src['input_ids'].to(self.device)
|
||||
src_mask = encoded_src['attention_mask'].to(self.device)
|
||||
src_tokens = encoded_src["input_ids"].to(self.device)
|
||||
src_mask = encoded_src["attention_mask"].to(self.device)
|
||||
|
||||
tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
|
||||
tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1)
|
||||
|
||||
output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
|
||||
logits = output.logits.view(-1, self.model.config.vocab_size)
|
||||
|
||||
pos_score = self.softmax(logits)[:, self.pos_id] # Yes
|
||||
neg_score = self.softmax(logits)[:, self.neg_id] # No
|
||||
pos_score = self.softmax(logits)[:, self.pos_id] # Yes
|
||||
neg_score = self.softmax(logits)[:, self.neg_id] # No
|
||||
|
||||
cur_pos_score = [x.item() for x in pos_score]
|
||||
cur_neg_score = [x.item() for x in neg_score]
|
||||
@@ -90,8 +85,8 @@ class UniEvaluator:
|
||||
neg_score_list += cur_neg_score
|
||||
|
||||
except RuntimeError:
|
||||
print(f'source: {src_list}')
|
||||
print(f'target: {tgt_list}')
|
||||
print(f"source: {src_list}")
|
||||
print(f"target: {tgt_list}")
|
||||
exit(0)
|
||||
|
||||
score_list = []
|
||||
|
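As a worked illustration of the scoring rule quoted in the docstring above (the numbers are made up):

pos_score, neg_score = 0.72, 0.18                  # softmax mass assigned to the "Yes" / "No" tokens
final_score = pos_score / (pos_score + neg_score)  # -> 0.8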
@@ -31,105 +31,142 @@ import tqdm
|
||||
|
||||
def add_question(dimension, output, src=None, ref=None, context=None, task=None):
|
||||
"""
|
||||
Add questions to generate input in Bool-QA format for UniEval.
|
||||
Add questions to generate input in Bool-QA format for UniEval.
|
||||
|
||||
dimension: specific dimension to be evaluated
|
||||
src: source input for different NLG tasks. For example, source document for summarization
|
||||
and dialogue history for dialogue response generation.
|
||||
output: output text generated by the models
|
||||
ref: human-annotated groundtruth
|
||||
context: the context needed to evaluate several specific dimensions. For example,
|
||||
additional factual information when evaluating engagingness and groundedness in dialogues.
|
||||
dimension: specific dimension to be evaluated
|
||||
src: source input for different NLG tasks. For example, source document for summarization
|
||||
and dialogue history for dialogue response generation.
|
||||
output: output text generated by the models
|
||||
ref: human-annotated groundtruth
|
||||
context: the context needed to evaluate several specific dimensions. For example,
|
||||
additional factual information when evaluating engagingness and groundedness in dialogues.
|
||||
"""
|
||||
|
||||
input_with_question = []
|
||||
for i in range(len(output)):
|
||||
# For summarization
|
||||
if task == 'summarization':
|
||||
if dimension == 'fluency':
|
||||
cur_input = 'question: Is this a fluent paragraph? </s> paragraph: ' + output[i]
|
||||
elif dimension == 'coherence':
|
||||
cur_input = 'question: Is this a coherent summary to the document? </s> summary: ' + output[
|
||||
i] + ' </s> document: ' + src[i]
|
||||
elif dimension == 'consistency':
|
||||
cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
|
||||
i] + ' </s> document: ' + src[i]
|
||||
elif dimension == 'relevance':
|
||||
cur_input = 'question: Is this summary relevant to the reference? </s> summary: ' + output[
|
||||
i] + ' </s> reference: ' + ref[i]
|
||||
if task == "summarization":
|
||||
if dimension == "fluency":
|
||||
cur_input = "question: Is this a fluent paragraph? </s> paragraph: " + output[i]
|
||||
elif dimension == "coherence":
|
||||
cur_input = (
|
||||
"question: Is this a coherent summary to the document? </s> summary: "
|
||||
+ output[i]
|
||||
+ " </s> document: "
|
||||
+ src[i]
|
||||
)
|
||||
elif dimension == "consistency":
|
||||
cur_input = (
|
||||
"question: Is this claim consistent with the document? </s> claim: "
|
||||
+ output[i]
|
||||
+ " </s> document: "
|
||||
+ src[i]
|
||||
)
|
||||
elif dimension == "relevance":
|
||||
cur_input = (
|
||||
"question: Is this summary relevant to the reference? </s> summary: "
|
||||
+ output[i]
|
||||
+ " </s> reference: "
|
||||
+ ref[i]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'The input format for this dimension is still undefined. Please customize it first.')
|
||||
"The input format for this dimension is still undefined. Please customize it first."
|
||||
)
|
||||
# For dialogues
|
||||
elif task == 'dialogue':
|
||||
if dimension == 'naturalness':
|
||||
cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + output[i]
|
||||
elif dimension == 'coherence':
|
||||
cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: '\
|
||||
+ output[i] + ' </s> dialogue history: ' + src[i]
|
||||
elif dimension == 'engagingness':
|
||||
cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: '\
|
||||
+ output[i] + ' </s> dialogue history: ' + src[i] + ' </s> fact: ' + context[i]
|
||||
elif dimension == 'groundedness':
|
||||
cur_input = 'question: Is this response consistent with knowledge in the fact? </s> response: '\
|
||||
+ output[i] + ' </s> fact: ' + context[i]
|
||||
elif dimension == 'understandability':
|
||||
cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + output[i]
|
||||
elif task == "dialogue":
|
||||
if dimension == "naturalness":
|
||||
cur_input = "question: Is this a natural response in the dialogue? </s> response: " + output[i]
|
||||
elif dimension == "coherence":
|
||||
cur_input = (
|
||||
"question: Is this a coherent response given the dialogue history? </s> response: "
|
||||
+ output[i]
|
||||
+ " </s> dialogue history: "
|
||||
+ src[i]
|
||||
)
|
||||
elif dimension == "engagingness":
|
||||
cur_input = (
|
||||
"question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: "
|
||||
+ output[i]
|
||||
+ " </s> dialogue history: "
|
||||
+ src[i]
|
||||
+ " </s> fact: "
|
||||
+ context[i]
|
||||
)
|
||||
elif dimension == "groundedness":
|
||||
cur_input = (
|
||||
"question: Is this response consistent with knowledge in the fact? </s> response: "
|
||||
+ output[i]
|
||||
+ " </s> fact: "
|
||||
+ context[i]
|
||||
)
|
||||
elif dimension == "understandability":
|
||||
cur_input = "question: Is this an understandable response in the dialogue? </s> response: " + output[i]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'The input format for this dimension is still undefined. Please customize it first.')
|
||||
"The input format for this dimension is still undefined. Please customize it first."
|
||||
)
|
||||
# For data-to-text
|
||||
elif task == 'data2text':
|
||||
if dimension == 'naturalness':
|
||||
cur_input = 'question: Is this a fluent utterance? </s> utterance: ' + output[i]
|
||||
elif dimension == 'informativeness':
|
||||
cur_input = 'question: Is this sentence informative according to the reference? </s> sentence: '\
|
||||
+ output[i] + ' </s> reference: ' + ref[i]
|
||||
elif task == "data2text":
|
||||
if dimension == "naturalness":
|
||||
cur_input = "question: Is this a fluent utterance? </s> utterance: " + output[i]
|
||||
elif dimension == "informativeness":
|
||||
cur_input = (
|
||||
"question: Is this sentence informative according to the reference? </s> sentence: "
|
||||
+ output[i]
|
||||
+ " </s> reference: "
|
||||
+ ref[i]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'The input format for this dimension is still undefined. Please customize it first.')
|
||||
"The input format for this dimension is still undefined. Please customize it first."
|
||||
)
|
||||
# For factual consistency detection
|
||||
elif task == 'fact':
|
||||
if dimension == 'consistency':
|
||||
cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
|
||||
i] + ' </s> document: ' + src[i]
|
||||
elif task == "fact":
|
||||
if dimension == "consistency":
|
||||
cur_input = (
|
||||
"question: Is this claim consistent with the document? </s> claim: "
|
||||
+ output[i]
|
||||
+ " </s> document: "
|
||||
+ src[i]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError('No other dimensions for the factual consistency detection task.')
|
||||
raise NotImplementedError("No other dimensions for the factual consistency detection task.")
|
||||
# For new customized tasks
|
||||
else:
|
||||
raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
|
||||
raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.")
|
||||
input_with_question.append(cur_input)
|
||||
return input_with_question
|
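For illustration, the summarization/coherence branch above turns a (document, summary) pair into a Bool-QA prompt like this (toy strings):

add_question(
    dimension="coherence",
    output=["A cat sat on a mat."],
    src=["The cat sat down on the mat in the kitchen."],
    task="summarization",
)
# -> ['question: Is this a coherent summary to the document? </s> summary: A cat sat on a mat. '
#     '</s> document: The cat sat down on the mat in the kitchen.']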
||||
|
||||
|
||||
def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
|
||||
"""
|
||||
Convert the data into UniEval's format.
|
||||
Convert the data into UniEval's format.
|
||||
|
||||
output_list: a list of model output
|
||||
output_list: a list of model output
|
||||
|
||||
src_list: source input for different NLG tasks. For example, source document for summarization
|
||||
and dialogue history for dialogue response generation
|
||||
ref_list: human-annotated groundtruth
|
||||
src_list: source input for different NLG tasks. For example, source document for summarization
|
||||
and dialogue history for dialogue response generation
|
||||
ref_list: human-annotated groundtruth
|
||||
"""
|
||||
json_data = []
|
||||
for i in range(len(output_list)):
|
||||
cur = {}
|
||||
cur['system_output'] = output_list[i]
|
||||
cur["system_output"] = output_list[i]
|
||||
if src_list is not None:
|
||||
cur['source'] = src_list[i]
|
||||
cur["source"] = src_list[i]
|
||||
if ref_list is not None:
|
||||
cur['reference'] = ref_list[i]
|
||||
cur['context'] = ""
|
||||
cur["reference"] = ref_list[i]
|
||||
cur["context"] = ""
|
||||
json_data.append(cur)
|
||||
return json_data
|
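A small example of the resulting record layout (the values are placeholders):

convert_data_to_unieval_format(["a generated summary"], src_list=["the source document"], ref_list=["a reference summary"])
# -> [{'system_output': 'a generated summary', 'source': 'the source document',
#      'reference': 'a reference summary', 'context': ''}]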
||||
|
||||
|
||||
def calculate_average_score(scores):
|
||||
"""
|
||||
Calculate average scores for different metrics
|
||||
Calculate average scores for different metrics
|
||||
|
||||
scores: a list of scores for different metrics for each answer
|
||||
scores: a list of scores for different metrics for each answer
|
||||
|
||||
"""
|
||||
metrics = {metric: 0 for metric in scores[0]}
|
||||
@@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None:
|
||||
frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
|
||||
|
||||
for metric in tqdm.tqdm(
|
||||
frame_per_metric.keys(),
|
||||
desc=f"UniEval metrics: ",
|
||||
total=len(frame_per_metric.keys()),
|
||||
frame_per_metric.keys(),
|
||||
desc=f"UniEval metrics: ",
|
||||
total=len(frame_per_metric.keys()),
|
||||
):
|
||||
data = pd.DataFrame(frame_per_metric[metric])
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
from typing import Dict
|
||||
|
||||
@@ -55,7 +54,7 @@ def jload(f, mode="r"):
|
||||
|
||||
|
||||
def get_json_list(file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
with open(file_path, "r") as f:
|
||||
json_list = []
|
||||
for line in f:
|
||||
json_list.append(json.loads(line))
|
||||
@@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
|
||||
frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))
|
||||
|
||||
for metric in tqdm.tqdm(
|
||||
frame_per_metric.keys(),
|
||||
desc=f"automatic metrics: ",
|
||||
total=len(frame_per_metric.keys()),
|
||||
frame_per_metric.keys(),
|
||||
desc=f"automatic metrics: ",
|
||||
total=len(frame_per_metric.keys()),
|
||||
):
|
||||
data = pd.DataFrame(frame_per_metric[metric])
|
||||
|
||||
|
@@ -3,7 +3,6 @@ import json
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
@@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
|
||||
padding="longest",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
) for text in strings
|
||||
)
|
||||
for text in strings
|
||||
]
|
||||
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
||||
input_ids_lens = labels_lens = [
|
||||
@@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
|
||||
|
||||
|
||||
class EasySupervisedDataset(Dataset):
|
||||
|
||||
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
|
||||
super(EasySupervisedDataset, self).__init__()
|
||||
with open(data_file, "r", encoding="UTF-8") as f:
|
||||
all_lines = f.readlines()
|
||||
#split each line into source and target: source is everything up to and including "回答:", target is everything after "回答:"
|
||||
# split each line into source and target: source is everything up to and including "回答:", target is everything after "回答:"
|
||||
sources, targets = [], []
|
||||
for line in all_lines:
|
||||
if "回答:" in line:
|
||||
sep_index = line.index("回答:")
|
||||
sources.append(line[:sep_index + 3])
|
||||
targets.append(line[sep_index + 3:] + tokenizer.eos_token)
|
||||
sources.append(line[: sep_index + 3])
|
||||
targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
|
||||
else:
|
||||
sources.append(line)
|
||||
targets.append("" + tokenizer.eos_token)
|
||||
@@ -83,15 +82,17 @@ class EasySupervisedDataset:
|
||||
|
||||
|
||||
class EasyPromptsDataset(Dataset):
|
||||
|
||||
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
|
||||
super(EasyPromptsDataset, self).__init__()
|
||||
with open(data_file, "r", encoding="UTF-8") as f:
|
||||
all_lines = f.readlines()
|
||||
all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
|
||||
all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
|
||||
self.prompts = [
|
||||
tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
|
||||
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
|
||||
tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
|
||||
"input_ids"
|
||||
]
|
||||
.to(torch.cuda.current_device())
|
||||
.squeeze(0)
|
||||
for line in tqdm(all_lines)
|
||||
]
|
||||
self.data_file = data_file
|
||||
@@ -110,7 +111,6 @@ class EasyPromptsDataset:
|
||||
|
||||
|
||||
class EasyRewardDataset(Dataset):
|
||||
|
||||
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
|
||||
super(EasyRewardDataset, self).__init__()
|
||||
self.chosen = []
|
||||
@@ -120,44 +120,42 @@ class EasyRewardDataset:
|
||||
else:
|
||||
self.end_token = special_token
|
||||
print(self.end_token)
|
||||
#read all lines in the train_file to a list
|
||||
# read all lines in the train_file to a list
|
||||
with open(train_file, "r", encoding="UTF-8") as f:
|
||||
all_lines = f.readlines()
|
||||
for line in tqdm(all_lines):
|
||||
data = json.loads(line)
|
||||
prompt = "提问:" + data['prompt'] + " 回答:"
|
||||
prompt = "提问:" + data["prompt"] + " 回答:"
|
||||
|
||||
chosen = prompt + data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
chosen = prompt + data["chosen"] + self.end_token
|
||||
chosen_token = tokenizer(
|
||||
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
self.chosen.append(
|
||||
{"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
|
||||
)
|
||||
|
||||
reject = prompt + data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
reject = prompt + data["rejected"] + self.end_token
|
||||
reject_token = tokenizer(
|
||||
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
self.reject.append(
|
||||
{"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
||||
return (
|
||||
self.chosen[idx]["input_ids"],
|
||||
self.chosen[idx]["attention_mask"],
|
||||
self.reject[idx]["input_ids"],
|
||||
self.reject[idx]["attention_mask"],
|
||||
)
|
||||
|
||||
#python representation of the object and the string representation of the object
|
||||
# python representation of the object and the string representation of the object
|
||||
def __repr__(self):
|
||||
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
|
||||
|
||||
@@ -165,26 +163,25 @@ class EasyRewardDataset:
|
||||
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
|
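From the parsing logic above, the reward file is expected to be JSON lines with prompt/chosen/rejected fields; a hedged usage sketch (the file name and record contents are made up):

# one line of the train_file, e.g.
# {"prompt": "什么是法律援助?", "chosen": "法律援助是...", "rejected": "不知道。"}
reward_dataset = EasyRewardDataset("reward_data.jsonl", tokenizer, max_length=512)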
||||
|
||||
|
||||
'''
|
||||
"""
|
||||
Easy SFT just accepts a text file that can be read line by line. However, the dataset will group texts together up to max_length so the LLM learns the meaning of the texts better.
|
||||
If individual lines are not related, just set is_group_texts to False.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class EasySFTDataset(Dataset):
|
||||
|
||||
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
|
||||
super().__init__()
|
||||
#read the data_file line by line
|
||||
# read the data_file line by line
|
||||
with open(data_file, "r", encoding="UTF-8") as f:
|
||||
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
|
||||
# encode the text data line by line and put raw python list input_ids only to raw_input_ids list
|
||||
raw_input_ids = []
|
||||
for line in f:
|
||||
encoded_ids = tokenizer.encode(line)
|
||||
#if the encoded_ids is longer than max_length, then split it into several parts
|
||||
# if the encoded_ids is longer than max_length, then split it into several parts
|
||||
if len(encoded_ids) > max_length:
|
||||
for i in range(0, len(encoded_ids), max_length):
|
||||
raw_input_ids.append(encoded_ids[i:i + max_length])
|
||||
raw_input_ids.append(encoded_ids[i : i + max_length])
|
||||
else:
|
||||
raw_input_ids.append(encoded_ids)
|
||||
|
||||
@@ -196,12 +193,13 @@ class EasySFTDataset:
|
||||
if is_group_texts:
|
||||
for input_ids in raw_input_ids:
|
||||
if len(current_input_ids) + len(input_ids) > max_length:
|
||||
#pad the current_input_ids to max_length with tokenizer.pad_token_id
|
||||
# pad the current_input_ids to max_length with tokenizer.pad_token_id
|
||||
padded_length = max_length - len(current_input_ids)
|
||||
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
|
||||
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
|
||||
attention_mask.append(
|
||||
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
|
||||
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
|
||||
)
|
||||
current_input_ids = []
|
||||
else:
|
||||
current_input_ids.extend(input_ids)
|
||||
@@ -210,14 +208,16 @@ class EasySFTDataset:
|
||||
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
|
||||
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
|
||||
attention_mask.append(
|
||||
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
|
||||
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
|
||||
)
|
||||
else:
|
||||
#just append the raw_input_ids to max_length
|
||||
# just append the raw_input_ids to max_length
|
||||
for input_ids in raw_input_ids:
|
||||
padded_length = max_length - len(input_ids)
|
||||
input_ids.extend([tokenizer.pad_token_id] * padded_length)
|
||||
attention_mask.append(
|
||||
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
|
||||
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
|
||||
)
|
||||
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
|
||||
self.input_ids = grouped_input_ids
|
||||
self.labels = copy.deepcopy(self.input_ids)
|
||||
@@ -227,14 +227,14 @@ class EasySFTDataset:
|
||||
def __len__(self):
|
||||
return len(self.input_ids)
|
||||
|
||||
#get item from dataset
|
||||
# get item from dataset
|
||||
def __getitem__(self, idx):
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
|
||||
|
||||
#generate the dataset description to be printed by print in python
|
||||
# generate the dataset description to be printed by print in python
|
||||
def __repr__(self):
|
||||
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
|
||||
|
||||
#generate the dataset description to be printed by print in python
|
||||
# generate the dataset description to be printed by print in python
|
||||
def __str__(self):
|
||||
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
|
||||
|
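A hedged usage sketch of the SFT dataset above (the file name is a placeholder; the constructor signature is taken from this diff):

sft_dataset = EasySFTDataset("sft_corpus.txt", tokenizer, max_length=512, is_group_texts=True)
sample = sft_dataset[0]  # dict with input_ids, labels and attention_mask, each padded to max_length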
@@ -4,7 +4,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from coati.models.generation import generate
|
||||
from coati.models.utils import log_probs_from_logits, masked_mean
|
||||
from coati.models.utils import log_probs_from_logits
|
||||
from peft import PeftModel
|
||||
from torch.nn.modules import Module
|
||||
from transformers import BloomConfig, BloomForCausalLM
|
||||
@@ -24,38 +24,33 @@ class Actor(Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
sequences = generate(self.model, input_ids, **kwargs)
|
||||
attention_mask = None
|
||||
pad_token_id = kwargs.get('pad_token_id', None)
|
||||
pad_token_id = kwargs.get("pad_token_id", None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
eos_token_id = kwargs.get("eos_token_id", None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
|
||||
|
||||
def forward(self,
|
||||
sequences: torch.LongTensor,
|
||||
num_actions: int,
|
||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Returns action log probs
|
||||
"""
|
||||
def forward(
|
||||
self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""Returns action log probs"""
|
||||
output = self.model(sequences, attention_mask=attention_mask)
|
||||
logits = output['logits']
|
||||
logits = output["logits"]
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
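To make the return values above concrete, a minimal sketch of how the actor is driven during experience collection (input_ids and tokenizer are assumed to exist; generation kwargs are illustrative):

sequences, attention_mask, action_mask = actor.generate(
    input_ids,
    return_action_mask=True,
    max_length=128,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
action_log_probs = actor(sequences, num_actions=action_mask.size(1), attention_mask=attention_mask)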
||||
@@ -75,11 +70,13 @@ class BLOOMActor(Actor):
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_path: str = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_path: str = None,
|
||||
) -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
|
@@ -1,18 +1,16 @@
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
|
||||
from coati.dataset import DataCollatorForSupervisedDataset
|
||||
from coati.models.bloom import BLOOMRM, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.models.gpt import GPTRM, GPTCritic
|
||||
from coati.models.llama import LlamaCritic, LlamaRM
|
||||
from coati.models.opt import OPTRM, OPTCritic
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
|
||||
from easy_models import BLOOMActor
|
||||
from peft import PeftModel
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
@@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
def main(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'ddp':
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
state_dict = torch.load(args.rm_path, map_location='cpu')
|
||||
state_dict = torch.load(args.rm_path, map_location="cpu")
|
||||
|
||||
# configure model
|
||||
if args.model == 'bloom':
|
||||
if args.model == "bloom":
|
||||
# initial_model = BLOOMActor(pretrained=args.pretrain)
|
||||
print('Using peft lora to load Bloom model as initial_model')
|
||||
print("Using peft lora to load Bloom model as initial_model")
|
||||
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
|
||||
print('Using peft lora to load Bloom model as initial_model (Done)')
|
||||
print("Using peft lora to load Bloom model as initial_model (Done)")
|
||||
else:
|
||||
raise ValueError(f'Unsupported actor model "{args.model}"')
|
||||
|
||||
@@ -49,59 +47,59 @@ def main(args):
|
||||
else:
|
||||
rm_model_name = args.rm_model
|
||||
|
||||
if rm_model_name == 'gpt2':
|
||||
if rm_model_name == "gpt2":
|
||||
reward_model = GPTRM(pretrained=args.rm_pretrain)
|
||||
elif rm_model_name == 'bloom':
|
||||
elif rm_model_name == "bloom":
|
||||
print("load bloom reward model ", args.rm_pretrain)
|
||||
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
|
||||
elif rm_model_name == 'opt':
|
||||
elif rm_model_name == "opt":
|
||||
reward_model = OPTRM(pretrained=args.rm_pretrain)
|
||||
elif rm_model_name == 'llama':
|
||||
elif rm_model_name == "llama":
|
||||
reward_model = LlamaRM(pretrained=args.rm_pretrain)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
print('Loading reward model from', args.rm_path)
|
||||
print("Loading reward model from", args.rm_path)
|
||||
reward_model.load_state_dict(state_dict)
|
||||
|
||||
if args.strategy != 'colossalai_gemini':
|
||||
if args.strategy != "colossalai_gemini":
|
||||
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
if args.model == "bloom":
|
||||
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
print('Using peft lora to load Bloom model as Actor')
|
||||
print("Using peft lora to load Bloom model as Actor")
|
||||
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
|
||||
print('Using peft lora to load Bloom model as Actor (Done)')
|
||||
print("Using peft lora to load Bloom model as Actor (Done)")
|
||||
else:
|
||||
raise ValueError(f'Unsupported actor model "{args.model}"')
|
||||
|
||||
if rm_model_name == 'gpt2':
|
||||
if rm_model_name == "gpt2":
|
||||
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
elif rm_model_name == 'bloom':
|
||||
elif rm_model_name == "bloom":
|
||||
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
|
||||
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
print("load bloom critic (Done) ")
|
||||
elif rm_model_name == 'opt':
|
||||
elif rm_model_name == "opt":
|
||||
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
elif rm_model_name == 'llama':
|
||||
elif rm_model_name == "llama":
|
||||
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
else:
|
||||
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
print('Loading reward model from', args.rm_path)
|
||||
print("Loading reward model from", args.rm_path)
|
||||
critic.load_state_dict(state_dict)
|
||||
del state_dict
|
||||
|
||||
if args.strategy != 'colossalai_gemini':
|
||||
if args.strategy != "colossalai_gemini":
|
||||
critic.to(torch.float16).to(torch.cuda.current_device())
|
||||
actor.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
if args.strategy.startswith("colossalai"):
|
||||
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
|
||||
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
|
||||
else:
|
||||
@@ -109,18 +107,18 @@ def main(args):
|
||||
critic_optim = Adam(critic.parameters(), lr=1e-7)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
if args.model == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
elif args.model == "bloom":
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
elif args.model == "opt":
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
elif args.model == "llama":
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer.eos_token = '</s>'
|
||||
tokenizer.eos_token = "<\s>"
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
@@ -132,26 +130,27 @@ def main(args):
|
||||
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
else:
|
||||
prompt_sampler = None
|
||||
prompt_dataloader = DataLoader(prompt_dataset,
|
||||
shuffle=(prompt_sampler is None),
|
||||
sampler=prompt_sampler,
|
||||
batch_size=args.train_batch_size)
|
||||
prompt_dataloader = DataLoader(
|
||||
prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
|
||||
)
|
||||
|
||||
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
else:
|
||||
pretrain_sampler = None
|
||||
pretrain_dataloader = DataLoader(pretrain_dataset,
|
||||
shuffle=(pretrain_sampler is None),
|
||||
sampler=pretrain_sampler,
|
||||
batch_size=args.ptx_batch_size,
|
||||
collate_fn=data_collator)
|
||||
pretrain_dataloader = DataLoader(
|
||||
pretrain_dataset,
|
||||
shuffle=(pretrain_sampler is None),
|
||||
sampler=pretrain_sampler,
|
||||
batch_size=args.ptx_batch_size,
|
||||
collate_fn=data_collator,
|
||||
)
|
||||
|
||||
def tokenize_fn(texts):
|
||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
||||
# Different length may lead to hang when using gemini, as different generation steps
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
|
||||
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
|
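An illustration of the invariant the comment above relies on (the prompt text is a placeholder; shape check only):

batch = tokenize_fn(["提问:什么是法律援助? 回答:"])
assert batch["input_ids"].shape == (1, 96)  # every rank produces identically shaped inputs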
||||
|
||||
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
|
||||
@@ -178,45 +177,46 @@ def main(args):
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
trainer.fit(prompt_dataloader=prompt_dataloader,
|
||||
pretrain_dataloader=pretrain_dataloader,
|
||||
num_episodes=args.num_episodes,
|
||||
num_update_steps=args.num_update_steps,
|
||||
num_collect_steps=args.num_collect_steps)
|
||||
trainer.fit(
|
||||
prompt_dataloader=prompt_dataloader,
|
||||
pretrain_dataloader=pretrain_dataloader,
|
||||
num_episodes=args.num_episodes,
|
||||
num_update_steps=args.num_update_steps,
|
||||
num_collect_steps=args.num_collect_steps,
|
||||
)
|
||||
|
||||
# save model checkpoint after fitting
|
||||
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(actor_optim,
|
||||
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
strategy.save_optimizer(
|
||||
actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
|
||||
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
|
||||
parser.add_argument('--strategy',
|
||||
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='ddp',
|
||||
help='strategy to use')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--sft_lora_path', type=str, default=None)
|
||||
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--rm_path', type=str, default=None)
|
||||
parser.add_argument('--rm_pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--num_collect_steps', type=int, default=10)
|
||||
parser.add_argument('--num_update_steps', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=2)
|
||||
parser.add_argument('--ptx_batch_size', type=int, default=1)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument('--kl_coef', type=float, default=0.1)
|
||||
parser.add_argument('--ptx_coef', type=float, default=0.9)
|
||||
parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
|
||||
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
|
||||
parser.add_argument(
|
||||
"--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
|
||||
)
|
||||
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--sft_lora_path", type=str, default=None)
|
||||
parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
|
||||
parser.add_argument("--rm_path", type=str, default=None)
|
||||
parser.add_argument("--rm_pretrain", type=str, default=None)
|
||||
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
|
||||
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
|
||||
parser.add_argument("--num_episodes", type=int, default=10)
|
||||
parser.add_argument("--num_collect_steps", type=int, default=10)
|
||||
parser.add_argument("--num_update_steps", type=int, default=5)
|
||||
parser.add_argument("--train_batch_size", type=int, default=2)
|
||||
parser.add_argument("--ptx_batch_size", type=int, default=1)
|
||||
parser.add_argument("--experience_batch_size", type=int, default=8)
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||
parser.add_argument("--ptx_coef", type=float, default=0.9)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
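A hedged launch sketch for the PPO script above (the launcher, script name and paths are placeholders; the flags come from the argparse block in this diff):

# torchrun --nproc_per_node=4 train_prompts_peft.py --model bloom --strategy colossalai_zero2 \
#     --pretrain /path/to/sft-bloom --sft_lora_path /path/to/sft-lora \
#     --rm_model bloom --rm_pretrain /path/to/bloom --rm_path /path/to/rm.pt \
#     --prompt_path prompts.txt --pretrain_dataset pretrain.txt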
@@ -1,18 +1,10 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.bloom import BLOOMLM
|
||||
from coati.models.gpt import GPTLM
|
||||
from coati.models.llama import LlamaLM
|
||||
from coati.models.opt import OPTLM
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
|
||||
from datasets import load_dataset
|
||||
from easy_dataset import EasyDataset
|
||||
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||
from torch.optim import Adam
|
||||
@@ -29,75 +21,76 @@ from colossalai.tensor import ColoParameter
|
||||
|
||||
def train(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'ddp':
|
||||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = GeminiStrategy(placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
with strategy.model_init_context():
|
||||
print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested')
|
||||
print("Warning: currently only bloom is tested, gpt2,llama and opt are not tested")
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device())
|
||||
# if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
|
||||
if os.path.exists(args.save_path) and os.path.exists(args.save_path + '/adapter_config.json') \
|
||||
and os.path.exists(args.save_path + '/adapter_model.bin'):
|
||||
if (
|
||||
os.path.exists(args.save_path)
|
||||
and os.path.exists(args.save_path + "/adapter_config.json")
|
||||
and os.path.exists(args.save_path + "/adapter_model.bin")
|
||||
):
|
||||
print("loading from saved peft model ", args.save_path)
|
||||
model = PeftModel.from_pretrained(model, args.save_path)
|
||||
else:
|
||||
# we'll use peft lora library to do the lora
|
||||
lora_rank = args.lora_rank if args.lora_rank > 0 else 32
|
||||
# config lora with rank of lora_rank
|
||||
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=lora_rank,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1)
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
if args.model == "gpt2":
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
elif args.model == "bloom":
|
||||
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
elif args.model == "opt":
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'llama':
|
||||
elif args.model == "llama":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrain,
|
||||
padding_side="right",
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer.eos_token = '</s>'
|
||||
tokenizer.eos_token = "<\s>"
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
if args.model == 'llama' and args.strategy == 'colossalai_gemini':
|
||||
if args.model == "llama" and args.strategy == "colossalai_gemini":
|
||||
# this is a hack to deal with the resized embedding
|
||||
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
|
||||
for name, param in model.named_parameters():
|
||||
if not isinstance(param, ColoParameter):
|
||||
sub_module_name = '.'.join(name.split('.')[:-1])
|
||||
weight_name = name.split('.')[-1]
|
||||
sub_module_name = ".".join(name.split(".")[:-1])
|
||||
weight_name = name.split(".")[-1]
|
||||
sub_module = model.get_submodule(sub_module_name)
|
||||
setattr(sub_module, weight_name, ColoParameter(param))
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
if args.strategy.startswith("colossalai"):
|
||||
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
|
||||
else:
|
||||
optim = Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.set_level('WARNING')
|
||||
logger.set_level("WARNING")
|
||||
|
||||
# configure dataset
|
||||
law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
|
||||
@@ -108,47 +101,57 @@ def train(args):
|
||||
eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text)
|
||||
data_collator = default_collate
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
train_sampler = DistributedSampler(train_dataset,
|
||||
shuffle=True,
|
||||
seed=42,
|
||||
drop_last=True,
|
||||
rank=dist.get_rank(),
|
||||
num_replicas=dist.get_world_size())
|
||||
train_sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
seed=42,
|
||||
drop_last=True,
|
||||
rank=dist.get_rank(),
|
||||
num_replicas=dist.get_world_size(),
|
||||
)
|
||||
if eval_dataset is not None:
|
||||
eval_sampler = DistributedSampler(eval_dataset,
|
||||
shuffle=False,
|
||||
seed=42,
|
||||
drop_last=False,
|
||||
rank=dist.get_rank(),
|
||||
num_replicas=dist.get_world_size())
|
||||
eval_sampler = DistributedSampler(
|
||||
eval_dataset,
|
||||
shuffle=False,
|
||||
seed=42,
|
||||
drop_last=False,
|
||||
rank=dist.get_rank(),
|
||||
num_replicas=dist.get_world_size(),
|
||||
)
|
||||
else:
|
||||
train_sampler = None
|
||||
eval_sampler = None
|
||||
|
||||
train_dataloader = DataLoader(train_dataset,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True,
|
||||
)
|
||||
if eval_dataset is not None:
|
||||
eval_dataloader = DataLoader(eval_dataset,
|
||||
shuffle=(eval_sampler is None),
|
||||
sampler=eval_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True)
|
||||
eval_dataloader = DataLoader(
|
||||
eval_dataset,
|
||||
shuffle=(eval_sampler is None),
|
||||
sampler=eval_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True,
|
||||
)
|
||||
else:
|
||||
eval_dataloader = None
|
||||
|
||||
trainer = SFTTrainer(model=model,
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps)
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
)
|
||||
|
||||
trainer.fit(logger=logger, log_interval=args.log_interval)
|
||||
|
||||
@@ -156,29 +159,27 @@ def train(args):
|
||||
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(trainer.optimizer,
|
||||
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
strategy.save_optimizer(
|
||||
trainer.optimizer, "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device()), only_rank0=False
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--strategy',
|
||||
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='ddp')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default=None)
|
||||
parser.add_argument('--eval_dataset', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='output')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--max_epochs', type=int, default=3)
|
||||
parser.add_argument('--batch_size', type=int, default=4)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
|
||||
parser.add_argument('--lr', type=float, default=5e-6)
|
||||
parser.add_argument('--accumulation_steps', type=int, default=8)
|
||||
parser.add_argument('--enable_peft_lora', action='store_true', default=False)
|
||||
parser.add_argument("--is_short_text", action='store_true', default=False)
|
||||
parser.add_argument("--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp")
|
||||
parser.add_argument("--model", choices=["gpt2", "bloom", "opt", "llama"], default="bloom")
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--dataset", type=str, default=None)
|
||||
parser.add_argument("--eval_dataset", type=str, default=None)
|
||||
parser.add_argument("--save_path", type=str, default="output")
|
||||
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
|
||||
parser.add_argument("--lr", type=float, default=5e-6)
|
||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--enable_peft_lora", action="store_true", default=False)
|
||||
parser.add_argument("--is_short_text", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
train(args)
|
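A hedged launch sketch for the SFT script above (the launcher, script name and file paths are placeholders; the flags come from the argparse block in this diff):

# torchrun --nproc_per_node=2 train_sft_peft.py --model bloom --strategy colossalai_zero2 \
#     --pretrain bigscience/bloom-560m --dataset law_corpus.txt --eval_dataset eval_corpus.txt \
#     --save_path output --lora_rank 8 --max_epochs 3 --batch_size 4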
||||
|
Some files were not shown because too many files have changed in this diff.