[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Hongxin Liu 2023-09-19 14:20:26 +08:00 committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions
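The change set below follows the usual pre-commit workflow. A minimal sketch, assuming pre-commit is installed from PyPI and the hooks are the ones declared in .pre-commit-config.yaml further down (file paths here are illustrative, not taken from the commit):

pip install pre-commit        # install the pre-commit tool
pre-commit autoupdate         # bump hook revisions in .pre-commit-config.yaml
pre-commit run --all-files    # re-run every configured hook over the whole repository

Running the hooks over the entire repository is what accounts for most of the 1268 changed files: the new black hook normalizes single-quoted strings to double quotes and reflows long call sites, as the per-file diffs below show.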

.flake8

@ -1,22 +0,0 @@
[flake8]
ignore =
;W503 line break before binary operator
W503,
;E203 whitespace before ':'
E203,
; exclude file
exclude =
.tox,
.git,
__pycache__,
build,
dist,
*.pyc,
*.egg-info,
.cache,
.eggs
max-line-length = 120
per-file-ignores = __init__.py:F401


@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2):
# If the corresponding item doesn't exist in the second directory, the directories are different
if not os.path.exists(item_path2):
print(f'Found mismatch: {item_path1}, {item_path2}')
print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# If the corresponding item is a directory, we compare the two directories recursively
if os.path.isdir(item_path1) and os.path.isdir(item_path2):
if not compare_dirs(item_path1, item_path2):
print(f'Found mismatch: {item_path1}, {item_path2}')
print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# both are files
@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):
# If the corresponding item is not a file or a directory, the directories are different
else:
print(f'Found mismatch: {item_path1}, {item_path2}')
print(f"Found mismatch: {item_path1}, {item_path2}")
return False
# If all items are the same, the directories are the same
return True
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.")
parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
args = parser.parse_args()
i18n_folders = os.listdir(args.directory)
@ -56,7 +56,7 @@ if __name__ == '__main__':
for i in range(1, len(i18n_folders)):
dir1 = i18n_folders[0]
dir2 = i18n_folders[i]
print(f'comparing {dir1} vs {dir2}')
print(f"comparing {dir1} vs {dir2}")
match = compare_dirs(i18n_folders[0], i18n_folders[i])
if not match:


@ -4,7 +4,7 @@ import os
def check_inputs(input_list):
for path in input_list:
real_path = os.path.join('examples', path)
real_path = os.path.join("examples", path)
if not os.path.exists(real_path):
return False
return True
@ -12,16 +12,16 @@ def check_inputs(input_list):
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="List of file names")
parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
args = parser.parse_args()
name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list)
if is_correct:
print('success')
print("success")
else:
print('failure')
print("failure")
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -17,21 +17,21 @@ def show_files(path, all_files):
def join(input_list, sep=None):
return (sep or ' ').join(input_list)
return (sep or " ").join(input_list)
def main():
contents = show_files('examples/', [])
contents = show_files("examples/", [])
all_loc = []
for file_loc in contents:
split_loc = file_loc.split('/')
split_loc = file_loc.split("/")
# must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
if len(split_loc) >= 4:
re_loc = '/'.join(split_loc[1:3])
re_loc = "/".join(split_loc[1:3])
if re_loc not in all_loc:
all_loc.append(re_loc)
print(all_loc)
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -3,7 +3,7 @@ import argparse
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files")
parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
args = parser.parse_args()
name_list = args.fileNameList.split(":")
folder_need_check = set()
@ -15,10 +15,10 @@ def main():
# - application
# - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
folder_need_check.add('/'.join(loc.split("/")[1:3]))
folder_need_check.add("/".join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values.
print(list(folder_need_check))
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]:
# prepare header
headers = {
'Authorization': f'Bearer {github_token}',
'Accept': 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28'
"Authorization": f"Bearer {github_token}",
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
res = requests.get(url, headers=headers).json()
repo_list = []
for item in res:
repo_list.append(item['name'])
repo_list.append(item["name"])
return repo_list
@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
"""
# prepare header
headers = {
'Authorization': f'Bearer {github_token}',
'Accept': 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28'
"Authorization": f"Bearer {github_token}",
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
user_engagement_count = {}
@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
# do pagination to the API
page = 1
while True:
comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}'
comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
comment_response = requests.get(comment_api, headers=headers).json()
if len(comment_response) == 0:
break
else:
for item in comment_response:
comment_author_relationship = item['author_association']
if comment_author_relationship != 'MEMBER':
comment_author_relationship = item["author_association"]
if comment_author_relationship != "MEMBER":
# if the comment is not made by our member
# we don't count this comment towards user engagement
continue
issue_id = item['issue_url'].split('/')[-1]
issue_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}'
issue_id = item["issue_url"].split("/")[-1]
issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}"
issue_response = requests.get(issue_api, headers=headers).json()
issue_author_relationship = issue_response['author_association']
issue_author_relationship = issue_response["author_association"]
if issue_author_relationship != 'MEMBER':
if issue_author_relationship != "MEMBER":
# this means that the issue/PR is not created by our own people
# any comments in this issue/PR by our member will be counted towards the leaderboard
member_name = item['user']['login']
member_name = item["user"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
@ -153,7 +153,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
if cursor is None:
offset_str = ""
else:
offset_str = f", after: \"{cursor}\""
offset_str = f', after: "{cursor}"'
query = f"""
{{
repository(owner: "{org_name}", name: "{repo_name}"){{
@ -182,7 +182,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
if cursor is None:
offset_str = ""
else:
offset_str = f", before: \"{cursor}\""
offset_str = f', before: "{cursor}"'
query = f"""
{{
repository(owner: "{org_name}", name: "{repo_name}"){{
@ -220,8 +220,8 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
# a utility function to make call to Github GraphQL API
def _call_graphql_api(query):
headers = {"Authorization": f"Bearer {github_token}"}
json_data = {'query': query}
response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers)
json_data = {"query": query}
response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers)
data = response.json()
return data
@ -234,21 +234,21 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
data = _call_graphql_api(query)
found_discussion_out_of_time_range = False
edges = data['data']['repository']['discussions']['edges']
edges = data["data"]["repository"]["discussions"]["edges"]
if len(edges) == 0:
break
else:
# keep the discussion whose author is not a member
for edge in edges:
# print the discussion title
discussion = edge['node']
discussion_updated_at = str2datetime(discussion['updatedAt'])
discussion = edge["node"]
discussion_updated_at = str2datetime(discussion["updatedAt"])
# check if the updatedAt is within the last 7 days
# if yes, add it to discussion_numbers
if discussion_updated_at > since:
if discussion['authorAssociation'] != 'MEMBER':
discussion_numbers.append(discussion['number'])
if discussion["authorAssociation"] != "MEMBER":
discussion_numbers.append(discussion["number"])
else:
found_discussion_out_of_time_range = True
@ -256,7 +256,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
break
else:
# update cursor
cursor = edges[-1]['cursor']
cursor = edges[-1]["cursor"]
# get the discussion comments and replies made by our member
user_engagement_count = {}
@ -269,42 +269,42 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
data = _call_graphql_api(query)
# get the comments
edges = data['data']['repository']['discussion']['comments']['edges']
edges = data["data"]["repository"]["discussion"]["comments"]["edges"]
# update the cursor
if len(edges) == 0:
break
else:
# update cursor for pagination
cursor = edges[-1]['cursor']
cursor = edges[-1]["cursor"]
for edge in edges:
comment = edge['node']
if comment['authorAssociation'] == 'MEMBER':
comment = edge["node"]
if comment["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days
# if yes, add it to user_engagement_count
comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if comment_updated_at > since:
member_name = comment['author']['login']
member_name = comment["author"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
else:
user_engagement_count[member_name] = 1
# get the replies
reply_edges = comment['replies']['edges']
reply_edges = comment["replies"]["edges"]
if len(reply_edges) == 0:
continue
else:
for reply_edge in reply_edges:
reply = reply_edge['node']
if reply['authorAssociation'] == 'MEMBER':
reply = reply_edge["node"]
if reply["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days
# if yes, add it to discussion_numbers
reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if reply_updated_at > since:
member_name = reply['author']['login']
member_name = reply["author"]["login"]
if member_name in user_engagement_count:
user_engagement_count[member_name] += 1
else:
@ -312,7 +312,9 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
return user_engagement_count
def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool:
def generate_user_engagement_leaderboard_image(
github_token: str, org_name: str, repo_list: List[str], output_path: str
) -> bool:
"""
Generate the user engagement leaderboard image for stats within the last 7 days
@ -335,16 +337,19 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
else:
total_engagement_count[name] = count
for repo_name in repo_list:
print(f"Fetching user engagement count for {repo_name}/{repo_name}")
issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str)
discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime)
issue_pr_engagement_count = get_issue_pull_request_comments(
github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str
)
discussion_engagement_count = get_discussion_comments(
github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime
)
# update the total engagement count
_update_count(issue_pr_engagement_count)
_update_count(discussion_engagement_count)
# prepare the data for plotting
x = []
y = []
@ -363,7 +368,7 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
# plot the leaderboard
xlabel = f"Number of Comments made (since {start_datetime_str})"
ylabel = "Member"
title = 'Active User Engagement Leaderboard'
title = "Active User Engagement Leaderboard"
plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True
else:
@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
"""
# request to the Github API to get the users who have contributed in the last 7 days
headers = {
'Authorization': f'Bearer {github_token}',
'Accept': 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28'
"Authorization": f"Bearer {github_token}",
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
counter = Counter()
start_datetime = get_utc_time_one_week_ago()
def _get_url(org_name, repo_name, page):
return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed'
return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed"
def _iterate_by_page(org_name, repo_name):
page = 1
@ -415,8 +420,8 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
# count the pull request and author from response
for pr_data in response:
merged_at = pr_data['merged_at']
author = pr_data['user']['login']
merged_at = pr_data["merged_at"]
author = pr_data["user"]["login"]
if merged_at is None:
continue
@ -439,7 +444,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
_iterate_by_page(org_name, repo_name)
# convert unix timestamp to Beijing datetime
bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai'))
bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai"))
bj_start_datetime_str = datetime2str(bj_start_datetime)
contribution_list = counter.to_sorted_list()
@ -452,7 +457,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
if len(author_list) > 0:
xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})"
ylabel = "Contributor"
title = 'Active Contributor Leaderboard'
title = "Active Contributor Leaderboard"
plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True
else:
@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str:
image_path (str): the path to the image to be uploaded
"""
url = "https://open.feishu.cn/open-apis/im/v1/images"
form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path
form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path
multi_form = MultipartEncoder(form)
headers = {
'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token
"Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token
}
headers['Content-Type'] = multi_form.content_type
headers["Content-Type"] = multi_form.content_type
response = requests.request("POST", url, headers=headers, data=multi_form).json()
return response['data']['image_key']
return response["data"]["image_key"]
def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
app_id (str): Lark app id
app_secret (str): Lark app secret
"""
url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal'
data = {'app_id': app_id, 'app_secret': app_secret}
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
data = {"app_id": app_id, "app_secret": app_secret}
response = requests.post(url, json=data).json()
return response['tenant_access_token']
return response["tenant_access_token"]
def send_image_to_lark(image_key: str, webhook_url: str) -> None:
@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str):
requests.post(webhook_url, json=data)
if __name__ == '__main__':
GITHUB_TOKEN = os.environ['GITHUB_TOKEN']
CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png'
USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png'
if __name__ == "__main__":
GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png"
USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png"
ORG_NAME = "hpcaitech"
# get all open source repositories
@ -527,17 +532,19 @@ if __name__ == '__main__':
# generate images
contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH)
engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH)
engagement_success = generate_user_engagement_leaderboard_image(
GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH
)
# upload images
APP_ID = os.environ['LARK_APP_ID']
APP_SECRET = os.environ['LARK_APP_SECRET']
APP_ID = os.environ["LARK_APP_ID"]
APP_SECRET = os.environ["LARK_APP_SECRET"]
LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET)
contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH)
user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)
# send message to lark
LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL']
LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"]
message = """本周的社区榜单出炉啦!
1. 开发贡献者榜单
2. 用户互动榜单


@ -7,27 +7,27 @@ import re
import requests
COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits'
TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags'
COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits"
TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--out', type=str, help='output path for the release draft', required=True)
parser.add_argument('--version', type=str, help='current version to release', required=True)
parser.add_argument("--out", type=str, help="output path for the release draft", required=True)
parser.add_argument("--version", type=str, help="current version to release", required=True)
return parser.parse_args()
def get_latest_tag_commit(headers=None):
res = requests.get(url=TAGS_API, headers=headers)
data = res.json()
commit_hash = data[0]['commit']['sha']
version = data[0]['name']
commit_hash = data[0]["commit"]["sha"]
version = data[0]["name"]
return commit_hash, version
def get_commit_info(commit_hash, headers=None):
api = f'{COMMIT_API}/{commit_hash}'
api = f"{COMMIT_API}/{commit_hash}"
res = requests.get(url=api, headers=headers)
return res.json()
@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None):
results = []
while True:
api = f'{COMMIT_API}?since={since}&per_page=100&page={page}'
api = f"{COMMIT_API}?since={since}&per_page=100&page={page}"
resp = requests.get(url=api, headers=headers)
data = resp.json()
@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None):
def collate_release_info(commit_info_list):
results = dict()
pattern = pattern = r'\[.*\]'
pattern = pattern = r"\[.*\]"
for commit_info in commit_info_list:
author = commit_info['commit']['author']['name']
author = commit_info["commit"]["author"]["name"]
try:
author_url = commit_info['author']['url']
author_url = commit_info["author"]["url"]
except:
# author can be None
author_url = None
msg = commit_info['commit']['message']
msg = commit_info["commit"]["message"]
match = re.search(pattern, msg)
if match:
tag = match.group().lstrip('[').rstrip(']').capitalize()
tag = match.group().lstrip("[").rstrip("]").capitalize()
if tag not in results:
results[tag] = []
results[tag].append((msg, author, author_url))
@ -89,42 +89,43 @@ def generate_release_post_markdown(current_version, last_version, release_info):
for msg, author, author_url in v:
# only keep the first line
msg = msg.split('\n')[0]
msg = msg.split("\n")[0]
if author_url:
item = f'{msg} by [{author}]({author_url})\n'
item = f"{msg} by [{author}]({author_url})\n"
else:
item = f'{msg} by {author}\n'
text.append(f'- {item}')
item = f"{msg} by {author}\n"
text.append(f"- {item}")
text.append('\n')
text.append("\n")
# add full change log
text.append(
f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}')
f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}"
)
return text
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
token = os.environ['GITHUB_API_TOKEN']
headers = {'Authorization': token}
token = os.environ["GITHUB_API_TOKEN"]
headers = {"Authorization": token}
# get previous release tag
last_release_commit, last_version = get_latest_tag_commit(headers)
last_release_commit_info = get_commit_info(last_release_commit, headers=headers)
last_release_date = last_release_commit_info['commit']['author']['date']
last_release_date = last_release_commit_info["commit"]["author"]["date"]
# get the commits since last release
commit_info = get_all_commit_info(since=last_release_date, headers=headers)
commit_info = commit_info[:-1] # remove the release commit
commit_info = commit_info[:-1] # remove the release commit
# collate into markdown
release_info = collate_release_info(commit_info)
markdown_text = generate_release_post_markdown(args.version, last_version, release_info)
# write into a file
with open(args.out, 'w') as f:
with open(args.out, "w") as f:
for line in markdown_text:
f.write(line)


@ -5,8 +5,8 @@ import requests
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--message', type=str)
parser.add_argument('-u', '--url', type=str)
parser.add_argument("-m", "--message", type=str)
parser.add_argument("-u", "--url", type=str)
return parser.parse_args()
@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url):
requests.post(webhook_url, json=data)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
send_message_to_lark(args.message, args.url)


@ -3,3 +3,4 @@ line_length = 120
multi_line_output=3
include_trailing_comma = true
ignore_comments = true
profile = black
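
The added profile = black setting makes isort emit import blocks that black will not reformat afterwards. A minimal equivalent command-line invocation, where the trailing "." is only an illustrative target directory:

isort --profile black --line-length 120 .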


@ -1,23 +1,31 @@
repos:
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
hooks:
- id: autoflake
name: autoflake (python)
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: sort all imports (python)
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.9.1
hooks:
- id: yapf
name: yapf formatter
args: ['--style=.style.yapf', '--parallel', '--in-place']
- id: black
name: black formatter
args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1
hooks:
- id: clang-format
name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
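
For reference, the black hook above corresponds roughly to invoking the formatter directly with the same options, and the clang-format hook formats C/C++ sources in place; a hedged sketch with illustrative file names:

black --line-length 120 --target-version py37 --target-version py310 path/to/module.py
clang-format -i path/to/kernel.cpp

Per the commit message, CUDA sources are excluded from clang-format, although the exclude pattern itself is not visible in the hunk shown here.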


@ -1,5 +0,0 @@
[style]
based_on_style = google
spaces_before_comment = 4
split_before_logical_operator = true
column_limit = 120


@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
def preprocess_batch(samples) -> dict:
input_ids = torch.stack(samples)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
return {'input_ids': input_ids, 'attention_mask': attention_mask}
return {"input_ids": input_ids, "attention_mask": attention_mask}
def print_rank_0(*args, **kwargs) -> None:
@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
B = 1024**3
M = 1024**2
K = 1024
outputs = ''
outputs = ""
for name, numel in model_dict.items():
outputs += f'{name}: '
outputs += f"{name}: "
if numel >= B:
outputs += f'{numel / B:.2f} B\n'
outputs += f"{numel / B:.2f} B\n"
elif numel >= M:
outputs += f'{numel / M:.2f} M\n'
outputs += f"{numel / M:.2f} M\n"
elif numel >= K:
outputs += f'{numel / K:.2f} K\n'
outputs += f"{numel / K:.2f} K\n"
else:
outputs += f'{numel}\n'
outputs += f"{numel}\n"
print_rank_0(outputs)
def get_gpt_config(model_name: str) -> OPTConfig:
model_map = {
'125m': OPTConfig.from_pretrained('facebook/opt-125m'),
'350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
'700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
'1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
'2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
'3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
'5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
'6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
'10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
'13b': OPTConfig.from_pretrained('facebook/opt-13b'),
"125m": OPTConfig.from_pretrained("facebook/opt-125m"),
"350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
"700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
"1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
"2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
"3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
"5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
"6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
"10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
"13b": OPTConfig.from_pretrained("facebook/opt-13b"),
}
try:
return model_map[model_name]
@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:
def main(args):
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif args.strategy == 'colossalai_gemini_cpu':
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif args.strategy == 'colossalai_zero1':
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda')
elif args.strategy == 'colossalai_zero1_cpu':
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif args.strategy == "colossalai_gemini_cpu":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == "colossalai_zero2_cpu":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
elif args.strategy == "colossalai_zero1":
strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
elif args.strategy == "colossalai_zero1_cpu":
strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
@ -103,90 +103,106 @@ def main(args):
if args.use_kernels:
from coati.kernels import convert_to_xformer_model
actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
(actor, critic, initial_model, reward_model))
actor, critic, initial_model, reward_model = map(
convert_to_xformer_model, (actor, critic, initial_model, reward_model)
)
actor_numel = get_model_numel(actor, strategy)
critic_numel = get_model_numel(critic, strategy)
initial_model_numel = get_model_numel(initial_model, strategy)
reward_model_numel = get_model_numel(reward_model, strategy)
print_model_numel({
'Actor': actor_numel,
'Critic': critic_numel,
'Initial model': initial_model_numel,
'Reward model': reward_model_numel
})
performance_evaluator = PerformanceEvaluator(actor_numel,
critic_numel,
initial_model_numel,
reward_model_numel,
enable_grad_checkpoint=False,
ignore_episodes=1)
print_model_numel(
{
"Actor": actor_numel,
"Critic": critic_numel,
"Initial model": initial_model_numel,
"Reward model": reward_model_numel,
}
)
performance_evaluator = PerformanceEvaluator(
actor_numel,
critic_numel,
initial_model_numel,
reward_model_numel,
enable_grad_checkpoint=False,
ignore_episodes=1,
)
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else:
actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6)
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
dataloader = DataLoader(random_prompts,
batch_size=args.experience_batch_size,
shuffle=True,
collate_fn=preprocess_batch)
dataloader = DataLoader(
random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
)
trainer = PPOTrainer(strategy,
actor,
critic,
reward_model,
initial_model,
actor_optim,
critic_optim,
ptx_coef=0,
train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models,
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator])
trainer = PPOTrainer(
strategy,
actor,
critic,
reward_model,
initial_model,
actor_optim,
critic_optim,
ptx_coef=0,
train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models,
max_length=512,
do_sample=True,
temperature=1.0,
top_k=50,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator],
)
trainer.fit(prompt_dataloader=dataloader,
pretrain_dataloader=None,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps)
trainer.fit(
prompt_dataloader=dataloader,
pretrain_dataloader=None,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps,
)
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='125m')
parser.add_argument('--critic_model', default='125m')
parser.add_argument('--strategy',
choices=[
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
],
default='ddp')
parser.add_argument('--num_episodes', type=int, default=3)
parser.add_argument('--num_collect_steps', type=int, default=8)
parser.add_argument('--num_update_steps', type=int, default=1)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0)
parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
parser.add_argument('--offload_inference_models', action='store_true', default=False)
parser.add_argument('--use_kernels', action='store_true', default=False)
parser.add_argument("--model", default="125m")
parser.add_argument("--critic_model", default="125m")
parser.add_argument(
"--strategy",
choices=[
"ddp",
"colossalai_gemini",
"colossalai_gemini_cpu",
"colossalai_zero2",
"colossalai_zero2_cpu",
"colossalai_zero1",
"colossalai_zero1_cpu",
],
default="ddp",
)
parser.add_argument("--num_episodes", type=int, default=3)
parser.add_argument("--num_collect_steps", type=int, default=8)
parser.add_argument("--num_update_steps", type=int, default=1)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0)
parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
parser.add_argument("--offload_inference_models", action="store_true", default=False)
parser.add_argument("--use_kernels", action="store_true", default=False)
args = parser.parse_args()
main(args)


@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@ -36,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_maker = {
'local_rank': '0',
'rank': '0',
'world_size': '1',
'master_port': maker_port,
'master_addr': master_addr
"local_rank": "0",
"rank": "0",
"world_size": "1",
"master_port": maker_port,
"master_addr": master_addr,
}
# configure tokenizer
@ -63,21 +66,27 @@ def main(args):
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.critic_model,
config=critic_cfg).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
reward_model = (
get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
)
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
# configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)],
detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn,
env_info=env_info_maker,
@ -97,15 +106,18 @@ def main(args):
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = get_critic_from_args(args.critic_model,
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
critic = (
get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
.half()
.cuda()
)
return actor, critic
# configure Trainer
trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[
f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn,
@ -114,7 +126,8 @@ def main(args):
buffer_limit=16,
eval_performance=True,
debug=args.debug,
) for i, env_info_trainer in enumerate(env_info_trainers)
)
for i, env_info_trainer in enumerate(env_info_trainers)
]
dataset_size = args.experience_batch_size * 4
@ -122,7 +135,7 @@ def main(args):
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask}
return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
@ -138,8 +151,10 @@ def main(args):
wait_tasks = []
wait_tasks.append(
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
num_steps=args.experience_steps))
experience_holder_ref.workingloop.remote(
partial(build_dataloader, dataset_size), num_steps=args.experience_steps
)
)
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs:
@ -148,31 +163,30 @@ def main(args):
ray.get(wait_tasks)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)


@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
s.bind(("", 0))
return s.getsockname()[1]
def get_local_ip():
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(('8.8.8.8', 80))
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
@ -36,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip())
# trainer_env_info
trainer_port = str(get_free_port())
env_info_trainers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_trainers),
'master_port': trainer_port,
'master_addr': master_addr
} for rank in range(args.num_trainers)]
env_info_trainers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_trainers),
"master_port": trainer_port,
"master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info
maker_port = str(get_free_port())
env_info_makers = [{
'local_rank': '0',
'rank': str(rank),
'world_size': str(args.num_makers),
'master_port': maker_port,
'master_addr': master_addr
} for rank in range(args.num_makers)]
env_info_makers = [
{
"local_rank": "0",
"rank": str(rank),
"world_size": str(args.num_makers),
"master_port": maker_port,
"master_addr": master_addr,
}
for rank in range(args.num_makers)
]
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
@ -63,14 +69,20 @@ def main(args):
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.critic_model,
config=critic_cfg).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama':
reward_model = (
get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
)
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model
with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits,
args.quant_group_size).cuda().requires_grad_(False)
initial_model.model = (
llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model
@ -79,7 +91,7 @@ def main(args):
experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[
f'trainer{x}'
f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
@ -103,8 +115,11 @@ def main(args):
def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = get_critic_from_args(args.critic_model,
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda()
critic = (
get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
.half()
.cuda()
)
return actor, critic
# configure Trainer
@ -130,7 +145,7 @@ def main(args):
def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask}
return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)]
@ -147,43 +162,48 @@ def main(args):
for experience_holder_ref in experience_holder_refs:
wait_tasks.append(
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size),
num_steps=args.experience_steps))
experience_holder_ref.workingloop.remote(
partial(build_dataloader, dataset_size), num_steps=args.experience_steps
)
)
total_steps = args.experience_batch_size * args.experience_steps * \
args.num_makers // (args.num_trainers * args.train_batch_size)
total_steps = (
args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_makers', type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1)
parser.add_argument('--trainer_strategy',
choices=[
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu',
'colossalai_zero2_cpu'
],
default='ddp')
parser.add_argument('--maker_strategy', choices=['naive'], default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None)
parser.add_argument('--experience_steps', type=int, default=4)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--train_epochs', type=int, default=1)
parser.add_argument('--update_steps', type=int, default=2)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument(
"--trainer_strategy",
choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
default="ddp",
)
parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None)
parser.add_argument('--quant_bits', type=int, default=4)
parser.add_argument('--quant_group_size', type=int, default=128)
parser.add_argument('--debug', action='store_true')
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args)


@ -4,7 +4,10 @@ from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0
__all__ = [
'RmStaticDataset', 'HhRlhfDataset',
'SFTDataset', 'SupervisedDataset',
'PromptDataset', 'is_rank_0',
"RmStaticDataset",
"HhRlhfDataset",
"SFTDataset",
"SupervisedDataset",
"PromptDataset",
"is_rank_0",
]


@ -49,7 +49,7 @@ class Conversation:
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
@ -57,12 +57,14 @@ class Conversation:
return ret
def copy(self):
return Conversation(system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep)
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
)
def dict(self):
return {
@ -70,7 +72,7 @@ class Conversation:
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep
"sep": self.sep,
}


@ -13,11 +13,13 @@ from .utils import jload
class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 96):
def __init__(
self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 96,
):
super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list)
self.logger = get_dist_logger()
@ -30,11 +32,9 @@ class PromptDataset(Dataset):
list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions,
return_tensors='pt',
max_length=max_length,
padding='max_length',
truncation=True)
tokens = tokenizer(
instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
)
for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()


@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
self.end_token = tokenizer.eos_token if special_token is None else special_token
chosen = [
data["prompt"] + data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
reject = [
data["prompt"] + data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
return (
self.chosen["input_ids"][idx],
self.chosen["attention_mask"][idx],
self.reject["input_ids"][idx],
self.reject["attention_mask"][idx],
)
# Anthropic/hh-rlhf
@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__()
self.end_token = tokenizer.eos_token \
if special_token is None else special_token
self.end_token = tokenizer.eos_token if special_token is None else special_token
chosen = [
data["chosen"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen = {
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
reject = [
data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
def __len__(self):
length = self.chosen["input_ids"].shape[0]
return length
def __getitem__(self, idx):
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \
self.reject["input_ids"][idx], self.reject["attention_mask"][idx]
return (
self.chosen["input_ids"][idx],
self.chosen["attention_mask"][idx],
self.reject["input_ids"][idx],
self.reject["attention_mask"][idx],
)


@ -16,10 +16,11 @@ import copy
from typing import Dict, Sequence, Tuple
import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload
@ -28,32 +29,33 @@ logger = get_dist_logger()
IGNORE_INDEX = -100
PROMPT_DICT = {
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
"prompt_no_input": ("Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"),
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
def _preprocess(sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)]
sequences_token = tokenizer(sequences,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
sources_token = tokenizer(sources,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
sequences_token = tokenizer(
sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
sources_token = tokenizer(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
@ -64,23 +66,24 @@ def _preprocess(sources: Sequence[str],
labels[i][:source_len] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
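The core of _preprocess is replacing prompt tokens in the labels with IGNORE_INDEX so the loss is computed only on the completion. A standalone sketch, assuming right padding and made-up token ids:

import torch

IGNORE_INDEX = -100

# 3 prompt tokens, 2 completion tokens, right padding with 0 up to max_length=8 (all made up).
sequence_ids = torch.tensor([[11, 12, 13, 21, 22, 0, 0, 0]])
source_len = 3

labels = sequence_ids.clone()
labels[0][:source_len] = IGNORE_INDEX  # |prompt|completion|pad| -> only the completion is supervised
print(labels)  # tensor([[-100, -100, -100, 21, 22, 0, 0, 0]])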
def _preprocess_chatglm(sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _preprocess_chatglm(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: PreTrainedTokenizer,
max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Preprocess the data by tokenizing.
None for attention mask, ChatGLM will calculate attention mask according to input ids
"""
labels = []
input_ids = []
for source, target in zip(sources, targets):
@ -90,16 +93,16 @@ def _preprocess_chatglm(sources: Sequence[str],
# truncate
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
truncate_length = max(0, len(input_id) - max_length)
input_id = input_id[truncate_length: ]
input_id = input_id[truncate_length:]
if truncate_length == len(source_id) + 1:
input_id = sp_token_list + input_id[1: ]
input_id = sp_token_list + input_id[1:]
elif truncate_length > len(source_id) + 1:
input_id = sp_token_list + input_id[2: ]
input_id = sp_token_list + input_id[2:]
context_length = input_id.index(tokenizer.bos_token_id)
mask_position = context_length - 1
label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:]
label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
pad_len = max_length - len(input_id)
input_id = input_id + [tokenizer.pad_token_id] * pad_len
input_ids.append(input_id)
@ -117,25 +120,18 @@ class SFTDataset(Dataset):
max_length: max length of input
"""
def __init__(self,
dataset: Dict,
tokenizer: PreTrainedTokenizer,
max_length: int = 512
) -> None:
def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
super().__init__()
self.input_ids = []
sources = [data["prompt"] for data in dataset]
targets = [
data["completion"] + tokenizer.eos_token
for data in tqdm(dataset, disable=not is_rank_0())
]
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = \
_preprocess_chatglm(sources, targets, tokenizer, max_length)
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
sources, targets, tokenizer, max_length
)
else:
self.input_ids, self.labels, self.attention_mask = \
_preprocess(sources, targets, tokenizer, max_length)
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
@ -143,22 +139,17 @@ class SFTDataset(Dataset):
def __getitem__(self, idx):
if self.attention_mask is not None:
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
else:
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx])
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self,
data_path: str,
tokenizer: PreTrainedTokenizer,
max_datasets_size: int = None,
max_length: int = 512):
def __init__(
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
):
super().__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
@ -174,18 +165,15 @@ class SupervisedDataset(Dataset):
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [
example['output'] + tokenizer.eos_token
for example in list_data_dict
]
targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]
logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = \
_preprocess_chatglm(sources, targets, tokenizer, max_length)
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
sources, targets, tokenizer, max_length
)
else:
self.input_ids, self.labels, self.attention_mask = \
_preprocess(sources, targets, tokenizer, max_length)
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
def __len__(self):
length = self.input_ids.shape[0]
@ -193,9 +181,6 @@ class SupervisedDataset(Dataset):
def __getitem__(self, idx):
if self.attention_mask is not None:
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
else:
return dict(input_ids=self.input_ids[idx],
labels=self.labels[idx])
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])

View File

@ -1,4 +1,4 @@
from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer
__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer']
__all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]

View File

@ -7,9 +7,9 @@ from coati.experience_maker.base import Experience
class ExperienceBuffer(ABC):
"""Experience buffer base class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:

View File

@ -11,23 +11,23 @@ from .utils import BufferItem, make_experience_batch, split_experience_batch
class NaiveExperienceBuffer(ExperienceBuffer):
"""Naive experience buffer class. It stores experience.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
Args:
sample_batch_size (int): Batch size when sampling.
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
"""
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
super().__init__(sample_batch_size, limit)
self.cpu_offload = cpu_offload
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
self.target_device = torch.device(f"cuda:{torch.cuda.current_device()}")
# TODO(ver217): add prefetch
self.items: List[BufferItem] = []
@torch.no_grad()
def append(self, experience: Experience) -> None:
if self.cpu_offload:
experience.to_device(torch.device('cpu'))
experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience)
self.items.extend(items)
if self.limit > 0:

View File

@ -21,6 +21,7 @@ class BufferItem:
"A" is the number of actions.
"""
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@ -33,8 +34,7 @@ class BufferItem:
def split_experience_batch(experience: Experience) -> List[BufferItem]:
batch_size = experience.sequences.size(0)
batch_kwargs = [{} for _ in range(batch_size)]
keys = ('sequences', 'action_log_probs', 'values',
'reward', 'advantages', 'attention_mask', 'action_mask')
keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
for key in keys:
value = getattr(experience, key)
if isinstance(value, torch.Tensor):
@ -49,22 +49,21 @@ def split_experience_batch(experience: Experience) -> List[BufferItem]:
return items
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
assert side in ('left', 'right')
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
assert side in ("left", "right")
max_len = max(seq.size(0) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(0)
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
padding = (pad_len, 0) if side == "left" else (0, pad_len)
padded_sequences.append(F.pad(seq, padding))
return torch.stack(padded_sequences, dim=0)
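A standalone sketch of the left/right zero padding above, used when re-batching variable-length action tensors; the example sequences are made up.

import torch
import torch.nn.functional as F

def zero_pad(sequences, side="left"):
    assert side in ("left", "right")
    max_len = max(seq.size(0) for seq in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded.append(F.pad(seq, padding))
    return torch.stack(padded, dim=0)

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(zero_pad(seqs, side="left"))   # tensor([[1, 2, 3], [0, 4, 5]])
print(zero_pad(seqs, side="right"))  # tensor([[1, 2, 3], [4, 5, 0]])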
def make_experience_batch(items: List[BufferItem]) -> Experience:
kwargs = {}
to_pad_keys = set(('action_log_probs', 'action_mask'))
keys = ('sequences', 'action_log_probs', 'values',
'reward', 'advantages', 'attention_mask', 'action_mask')
to_pad_keys = set(("action_log_probs", "action_mask"))
keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
for key in keys:
vals = [getattr(item, key) for item in items]
if key in to_pad_keys:

View File

@ -1,4 +1,4 @@
from .base import Experience, ExperienceMaker
from .naive import NaiveExperienceMaker
__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
__all__ = ["Experience", "ExperienceMaker", "NaiveExperienceMaker"]

View File

@ -24,6 +24,7 @@ class Experience:
"A" is the number of actions.
"""
sequences: torch.Tensor
action_log_probs: torch.Tensor
values: torch.Tensor
@ -58,13 +59,9 @@ class Experience:
class ExperienceMaker(ABC):
def __init__(self,
actor: Actor,
critic: nn.Module,
reward_model: nn.Module,
initial_model: Actor,
kl_coef: float = 0.1) -> None:
def __init__(
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
) -> None:
super().__init__()
self.actor = actor
self.critic = critic

View File

@ -23,22 +23,21 @@ class NaiveExperienceMaker(ExperienceMaker):
# calculate auxiliary tensors
attention_mask = None
pad_token_id = generate_kwargs.get('pad_token_id', None)
pad_token_id = generate_kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id)\
.to(dtype=torch.long, device=sequences.device)
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get('eos_token_id', None)
eos_token_id = generate_kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)
actor_output = self.actor(sequences, attention_mask)
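A standalone sketch of the action-mask computation above: generated tokens up to and including the first EOS count as actions, anything after the EOS is masked out. The token ids, prompt length and EOS id are made up.

import torch
import torch.nn.functional as F

eos_token_id = 9
input_len = 3                                      # number of prompt tokens
sequences = torch.tensor([[1, 2, 3, 5, 6, 9, 7]])  # prompt + generated [5, 6, EOS, 7]

action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)  # shift so the EOS itself stays valid
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
action_mask = action_mask[:, -(sequences.size(1) - input_len):]
print(action_mask)  # tensor([[ True,  True,  True, False]]): 5, 6 and the EOS count, the token after EOS does not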

View File

@ -1,6 +1,6 @@
from .wrapper import convert_to_xformer_model, recover_from_xformer_model
__all__ = [
'convert_to_xformer_model',
'recover_from_xformer_model',
"convert_to_xformer_model",
"recover_from_xformer_model",
]

View File

@ -21,11 +21,12 @@ class XOPTAttention(OPTAttention):
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
if not self.training:
return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask,
output_attentions)
return super().forward(
hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, output_attentions
)
"""Input shape: Batch x Time x Channel"""
assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask'
assert not output_attentions, 'Xformers attention does not support output_attentions'
assert layer_head_mask is None, "Xformers attention does not support layer_head_mask"
assert not output_attentions, "Xformers attention does not support output_attentions"
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
@ -69,12 +70,14 @@ class XOPTAttention(OPTAttention):
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = xops.memory_efficient_attention(query_states,
key_states,
value_states,
attn_bias=xops.LowerTriangularMask(),
p=self.dropout if self.training else 0.0,
scale=self.scaling)
attn_output = xops.memory_efficient_attention(
query_states,
key_states,
value_states,
attn_bias=xops.LowerTriangularMask(),
p=self.dropout if self.training else 0.0,
scale=self.scaling,
)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
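For reference, the xformers call above computes standard causal scaled-dot-product attention over [batch, seq, heads, head_dim] inputs. A plain-PyTorch sketch of the same computation (ignoring dropout), with arbitrary sizes:

import math
import torch

def causal_attention_reference(q, k, v, scale=None):
    # q, k, v: [batch, seq, heads, head_dim], the layout xformers expects
    if scale is None:
        scale = 1.0 / math.sqrt(q.size(-1))
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> [batch, heads, seq, head_dim]
    scores = (q @ k.transpose(-2, -1)) * scale
    causal = torch.triu(torch.ones(q.size(-2), k.size(-2), dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)  # back to [batch, seq, heads, head_dim]

q = k = v = torch.randn(2, 4, 8, 16)
print(causal_attention_reference(q, k, v).shape)  # torch.Size([2, 4, 8, 16])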

View File

@ -3,6 +3,13 @@ from .lora import LoRAModule, convert_to_lora_module
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
__all__ = [
'Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'LogSigLoss', 'LogExpLoss',
'LoRAModule', 'convert_to_lora_module'
"Actor",
"Critic",
"RewardModel",
"PolicyLoss",
"ValueLoss",
"LogSigLoss",
"LogExpLoss",
"LoRAModule",
"convert_to_lora_module",
]

View File

@ -9,7 +9,7 @@ from .reward_model import RewardModel
def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
"""Get the base model of our wrapper classes.
For Actor, Critic and RewardModel, return ``model.model``,
For Actor, Critic and RewardModel, return ``model.model``,
it's usually a ``transformers.PreTrainedModel``.
Args:
@ -18,9 +18,10 @@ def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
Returns:
nn.Module: the base model
"""
assert isinstance(model, (Actor, Critic, RewardModel)), \
f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
assert isinstance(
model, (Actor, Critic, RewardModel)
), f"Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first."
return model.model
__all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
__all__ = ["Actor", "Critic", "RewardModel", "get_base_model"]

View File

@ -16,18 +16,17 @@ class Actor(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
) -> torch.Tensor:
"""Returns model output.
"""
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output

View File

@ -23,22 +23,23 @@ class Critic(LoRAModule):
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
lora_train_bias: str = 'none',
lora_train_bias: str = "none",
use_action_mask: bool = False,
) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora()
def forward(self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state']
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states).squeeze(-1)

View File

@ -17,11 +17,13 @@ class RewardModel(LoRAModule):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
model: nn.Module,
value_head: Optional[nn.Module] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
model: nn.Module,
value_head: Optional[nn.Module] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.convert_to_lora()
@ -35,7 +37,7 @@ class RewardModel(LoRAModule):
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state']
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value
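A shape sketch of the reward head above: each hidden state is mapped to a scalar, the last position is dropped, and the mean over the sequence gives one reward per sample. The random tensor stands in for the backbone's last_hidden_state and the sizes are arbitrary.

import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 6, 16
last_hidden_state = torch.randn(batch, seq_len, hidden)  # stand-in for the transformer output
value_head = nn.Linear(hidden, 1)

values = value_head(last_hidden_state)[:, :-1]  # [batch, seq_len - 1, 1]: drop the last position
value = values.mean(dim=1).squeeze(1)           # [batch]: one scalar reward per sequence
print(value.shape)  # torch.Size([2])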

View File

@ -2,4 +2,4 @@ from .bloom_actor import BLOOMActor
from .bloom_critic import BLOOMCritic
from .bloom_rm import BLOOMRM
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM']
__all__ = ["BLOOMActor", "BLOOMCritic", "BLOOMRM"]

View File

@ -1,7 +1,6 @@
from typing import Optional
import torch
from transformers import BloomConfig, BloomForCausalLM, BloomModel
from transformers import BloomConfig, BloomForCausalLM
from ..base import Actor
@ -18,12 +17,14 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:

View File

@ -1,8 +1,7 @@
from typing import Optional
import torch
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel
from transformers import BloomConfig, BloomModel
from ..base import Critic
@ -18,12 +17,14 @@ class BLOOMCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
from transformers import BloomConfig, BloomForCausalLM, BloomModel
from transformers import BloomConfig, BloomModel
from ..base import RewardModel
@ -17,11 +17,13 @@ class BLOOMRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = BloomModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -1,3 +1,3 @@
from .chatglm_actor import ChatGLMActor
__all__ = ['ChatGLMActor']
__all__ = ["ChatGLMActor"]

View File

@ -1,11 +1,9 @@
from typing import Optional
import torch
from ..base import Actor
from .configuration_chatglm import ChatGLMConfig
from .modeling_chatglm import ChatGLMForConditionalGeneration
from ..base import Actor
class ChatGLMActor(Actor):
"""
@ -19,10 +17,9 @@ class ChatGLMActor(Actor):
do not support lora for now.
"""
def __init__(self,
pretrained: str = None,
config: Optional[ChatGLMConfig] = None,
checkpoint: bool = False) -> None:
def __init__(
self, pretrained: str = None, config: Optional[ChatGLMConfig] = None, checkpoint: bool = False
) -> None:
if pretrained is not None:
model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
elif config is not None:
@ -31,4 +28,4 @@ class ChatGLMActor(Actor):
model = ChatGLMForConditionalGeneration(ChatGLMConfig())
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank=0, lora_train_bias='none')
super().__init__(model, lora_rank=0, lora_train_bias="none")

View File

@ -2,15 +2,14 @@
This code is copied from https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py
"""
"""Tokenization classes for ChatGLM."""
from typing import List, Optional, Union
import os
from typing import Dict, List, Optional, Union
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from typing import Dict
import sentencepiece as spm
import numpy as np
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
from transformers.utils import PaddingStrategy, logging
logger = logging.get_logger(__name__)
@ -52,11 +51,11 @@ class TextTokenizer:
class SPTokenizer:
def __init__(
self,
vocab_file,
num_image_tokens=20000,
max_blank_length=80,
byte_fallback=True,
self,
vocab_file,
num_image_tokens=20000,
max_blank_length=80,
byte_fallback=True,
):
assert vocab_file is not None
self.vocab_file = vocab_file
@ -100,9 +99,7 @@ class SPTokenizer:
text = self._encode_whitespaces(text, max_len=self.max_blank_length)
return text
def encode(
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
) -> List[int]:
def encode(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[int]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (\n) in text.
@ -136,9 +133,7 @@ class SPTokenizer:
text = self.postprocess(text)
return text
def tokenize(
self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
) -> List[str]:
def tokenize(self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True) -> List[str]:
"""
@param text: Text to encode.
@param linebreak: Whether to encode newline (\n) in text.
@ -181,20 +176,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(
self,
vocab_file,
do_lower_case=False,
remove_space=False,
bos_token='<sop>',
eos_token='<eop>',
end_token='</s>',
mask_token='[MASK]',
gmask_token='[gMASK]',
padding_side="left",
pad_token="<pad>",
unk_token="<unk>",
num_image_tokens=20000,
**kwargs
self,
vocab_file,
do_lower_case=False,
remove_space=False,
bos_token="<sop>",
eos_token="<eop>",
end_token="</s>",
mask_token="[MASK]",
gmask_token="[gMASK]",
padding_side="left",
pad_token="<pad>",
unk_token="<unk>",
num_image_tokens=20000,
**kwargs,
) -> None:
super().__init__(
do_lower_case=do_lower_case,
@ -208,7 +203,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
pad_token=pad_token,
unk_token=unk_token,
num_image_tokens=num_image_tokens,
**kwargs
**kwargs,
)
self.do_lower_case = do_lower_case
@ -243,11 +238,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
@property
def vocab_size(self):
""" Returns vocab size """
"""Returns vocab size"""
return self.sp_tokenizer.num_tokens
def get_vocab(self):
""" Returns vocab as a dict """
"""Returns vocab as a dict"""
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
@ -264,7 +259,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return outputs
def _tokenize(self, text, **kwargs):
""" Returns a tokenized string. """
"""Returns a tokenized string."""
text = self.preprocess_text(text)
seq = self.sp_tokenizer.tokenize(text)
@ -274,11 +269,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.sp_tokenizer.decode_tokens(tokens)
def _decode(
self,
token_ids: Union[int, List[int]],
**kwargs
) -> str:
def _decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if len(token_ids) == 0:
@ -288,7 +279,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return super()._decode(token_ids, **kwargs)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
"""Converts a token (str) in an id using the vocab."""
return self.sp_tokenizer[token]
def _convert_id_to_token(self, index):
@ -309,13 +300,11 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
with open(self.vocab_file, "rb") as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
@ -324,7 +313,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return (vocab_file,)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
@ -343,19 +332,19 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
gmask_id = self.sp_tokenizer[self.gmask_token]
eos_id = self.sp_tokenizer[self.eos_token]
self.sp_tokenizer[self.eos_token]
token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]]
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1
return token_ids_0
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@ -421,17 +410,23 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
mask_position = required_input.index(mask_token)
position_ids[context_length:] = mask_position
block_position_ids = np.concatenate(
[np.zeros(context_length, dtype=np.int64),
np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
[
np.zeros(context_length, dtype=np.int64),
np.arange(1, seq_length - context_length + 1, dtype=np.int64),
]
)
encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
pad_width=[(0, 0), (difference, 0), (difference, 0)],
mode='constant', constant_values=True)
encoded_inputs["attention_mask"] = np.pad(
encoded_inputs["attention_mask"],
pad_width=[(0, 0), (difference, 0), (difference, 0)],
mode="constant",
constant_values=True,
)
if "token_type_ids" in encoded_inputs:
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
"token_type_ids"
@ -439,8 +434,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
if "special_tokens_mask" in encoded_inputs:
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
if "position_ids" in encoded_inputs:
encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
pad_width=[(0, 0), (difference, 0)])
encoded_inputs["position_ids"] = np.pad(
encoded_inputs["position_ids"], pad_width=[(0, 0), (difference, 0)]
)
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
return encoded_inputs
return encoded_inputs
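A numpy sketch of the 2-D position ids assembled in _pad above: absolute positions that freeze at the mask position, plus block positions counting generated tokens, then left padding of both rows. seq_length, context_length, mask_position and the padding amount are made-up values.

import numpy as np

seq_length, context_length, mask_position = 7, 4, 3

position_ids = np.arange(seq_length, dtype=np.int64)
position_ids[context_length:] = mask_position  # positions after the context stick to the mask token
block_position_ids = np.concatenate(
    [
        np.zeros(context_length, dtype=np.int64),
        np.arange(1, seq_length - context_length + 1, dtype=np.int64),  # 1, 2, ... over generated tokens
    ]
)
position_ids = np.stack([position_ids, block_position_ids], axis=0)  # shape (2, seq_length)

difference = 3  # left padding, as for padding_side="left"
print(np.pad(position_ids, pad_width=[(0, 0), (difference, 0)]))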

View File

@ -56,30 +56,29 @@ class ChatGLMConfig(PretrainedConfig):
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
```"""
model_type = "chatglm"
def __init__(
self,
vocab_size=130528,
hidden_size=4096,
num_layers=28,
num_attention_heads=32,
layernorm_epsilon=1e-5,
use_cache=True,
bos_token_id=130004,
eos_token_id=130005,
mask_token_id=130000,
gmask_token_id=130001,
pad_token_id=3,
max_sequence_length=2048,
inner_hidden_size=16384,
position_encoding_2d=True,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs
self,
vocab_size=130528,
hidden_size=4096,
num_layers=28,
num_attention_heads=32,
layernorm_epsilon=1e-5,
use_cache=True,
bos_token_id=130004,
eos_token_id=130005,
mask_token_id=130000,
gmask_token_id=130001,
pad_token_id=3,
max_sequence_length=2048,
inner_hidden_size=16384,
position_encoding_2d=True,
quantization_bit=0,
pre_seq_len=None,
prefix_projection=False,
**kwargs,
):
self.num_layers = num_layers
self.vocab_size = vocab_size
@ -99,9 +98,4 @@ class ChatGLMConfig(PretrainedConfig):
self.pre_seq_len = pre_seq_len
self.prefix_projection = prefix_projection
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs
)
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

View File

@ -4,41 +4,40 @@ This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/mo
""" PyTorch ChatGLM model. """
import math
import copy
import math
import os
import warnings
import re
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
from .configuration_chatglm import ChatGLMConfig
# flags required to enable jit fusion kernels
if sys.platform != 'darwin':
if sys.platform != "darwin":
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
@ -93,8 +92,8 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info(f"Skipping {'/'.join(name)}")
continue
@ -127,7 +126,7 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
@ -153,7 +152,7 @@ class PrefixEncoder(torch.nn.Module):
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)
@ -170,8 +169,7 @@ class PrefixEncoder(torch.nn.Module):
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
def gelu(x):
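gelu_impl above is the tanh approximation of GELU, with 0.7978845608028654 = sqrt(2 / pi). A quick check against PyTorch's built-in tanh-approximate GELU (assuming PyTorch >= 1.12, which added the approximate argument):

import torch
import torch.nn.functional as F

def gelu_tanh(x):
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))

x = torch.randn(1000)
print(torch.allclose(gelu_tanh(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True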
@ -181,21 +179,22 @@ def gelu(x):
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = inv_freq.half()
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
self.register_buffer('inv_freq', inv_freq)
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
pass
def forward(self, x, seq_dim=1, seq_len=None):
@ -204,7 +203,7 @@ class RotaryEmbedding(torch.nn.Module):
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
@ -230,30 +229,31 @@ class RotaryEmbedding(torch.nn.Module):
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
position_id, sin.squeeze(1)
).unsqueeze(2)
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k
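A tiny sketch of rotate_half on a made-up vector: the second half of the last dimension is negated and moved in front of the first half, which is what lets q * cos + rotate_half(q) * sin apply a position-dependent rotation.

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))  # tensor([-3., -4.,  1.,  2.])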
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
hidden_size_per_partition,
layer_id,
layer_past=None,
scaling_attention_score=True,
use_cache=False,
self,
query_layer,
key_layer,
value_layer,
attention_mask,
hidden_size_per_partition,
layer_id,
layer_past=None,
scaling_attention_score=True,
use_cache=False,
):
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
@ -285,7 +285,9 @@ def attention_fn(
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
matmul_result = torch.zeros(
1, 1, 1,
1,
1,
1,
dtype=query_layer.dtype,
device=query_layer.device,
)
@ -355,9 +357,17 @@ def default_init(cls, *args, **kwargs):
class SelfAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads,
layer_id, hidden_size_per_attention_head=None, bias=True,
params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
def __init__(
self,
hidden_size,
num_attention_heads,
layer_id,
hidden_size_per_attention_head=None,
bias=True,
params_dtype=torch.float,
position_encoding_2d=True,
empty_init=True,
):
if empty_init:
init_method = skip_init
else:
@ -410,8 +420,7 @@ class SelfAttention(torch.nn.Module):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def split_tensor_along_last_dim(self, tensor, num_partitions,
contiguous_split_chunks=False):
def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
@ -431,14 +440,14 @@ class SelfAttention(torch.nn.Module):
return tensor_list
def forward(
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
@ -462,8 +471,10 @@ class SelfAttention(torch.nn.Module):
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
position_ids[:, 1, :].transpose(0, 1).contiguous()
position_ids, block_position_ids = (
position_ids[:, 0, :].transpose(0, 1).contiguous(),
position_ids[:, 1, :].transpose(0, 1).contiguous(),
)
q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
@ -484,7 +495,7 @@ class SelfAttention(torch.nn.Module):
hidden_size_per_partition=self.hidden_size_per_partition,
layer_id=layer_id,
layer_past=layer_past,
use_cache=use_cache
use_cache=use_cache,
)
output = self.dense(context_layer)
@ -509,8 +520,16 @@ class GEGLU(torch.nn.Module):
class GLU(torch.nn.Module):
def __init__(self, hidden_size, inner_hidden_size=None,
layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
def __init__(
self,
hidden_size,
inner_hidden_size=None,
layer_id=None,
bias=True,
activation_func=gelu,
params_dtype=torch.float,
empty_init=True,
):
super(GLU, self).__init__()
if empty_init:
init_method = skip_init
@ -557,19 +576,19 @@ class GLU(torch.nn.Module):
class GLMBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
num_attention_heads,
layernorm_epsilon,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
layernorm=LayerNorm,
use_bias=True,
params_dtype=torch.float,
num_layers=28,
position_encoding_2d=True,
empty_init=True
self,
hidden_size,
num_attention_heads,
layernorm_epsilon,
layer_id,
inner_hidden_size=None,
hidden_size_per_attention_head=None,
layernorm=LayerNorm,
use_bias=True,
params_dtype=torch.float,
num_layers=28,
position_encoding_2d=True,
empty_init=True,
):
super(GLMBlock, self).__init__()
# Set output layer initialization if not provided.
@ -590,7 +609,7 @@ class GLMBlock(torch.nn.Module):
bias=use_bias,
params_dtype=params_dtype,
position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init
empty_init=empty_init,
)
# Layernorm on the input data.
@ -605,18 +624,18 @@ class GLMBlock(torch.nn.Module):
bias=use_bias,
layer_id=layer_id,
params_dtype=params_dtype,
empty_init=empty_init
empty_init=empty_init,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
self,
hidden_states: torch.Tensor,
position_ids,
attention_mask: torch.Tensor,
layer_id,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
"""
hidden_states: [seq_len, batch, hidden_size]
@ -635,7 +654,7 @@ class GLMBlock(torch.nn.Module):
layer_id=layer_id,
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions
output_attentions=output_attentions,
)
attention_output = attention_outputs[0]
@ -702,10 +721,15 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths):
position_ids[i, context_length:] = mask_positions[i]
block_position_ids = [torch.cat((
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
)) for context_length in context_lengths]
block_position_ids = [
torch.cat(
(
torch.zeros(context_length, dtype=torch.long, device=device),
torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
)
)
for context_length in context_lengths
]
block_position_ids = torch.stack(block_position_ids, dim=0)
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
else:
@ -823,9 +847,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.prefix_projection = config.prefix_projection
self.word_embeddings = init_method(
torch.nn.Embedding,
num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
dtype=self.params_dtype
torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
)
self.gradient_checkpointing = False
@ -841,12 +863,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_bias=True,
params_dtype=self.params_dtype,
position_encoding_2d=self.position_encoding_2d,
empty_init=empty_init
empty_init=empty_init,
)
self.layers = torch.nn.ModuleList(
[get_layer(layer_id) for layer_id in range(self.num_layers)]
)
self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)
@ -876,7 +896,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.pre_seq_len,
self.num_layers * 2,
self.num_attention_heads,
self.hidden_size // self.num_attention_heads
self.hidden_size // self.num_attention_heads,
)
# seq_len, b, nh, hidden_size
past_key_values = self.dropout(past_key_values)
@ -891,18 +911,17 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@ -931,17 +950,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
if past_key_values is None:
if self.pre_seq_len is not None:
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
dtype=inputs_embeds.dtype)
past_key_values = self.get_prompt(
batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
)
else:
past_key_values = tuple([None] * len(self.layers))
if attention_mask is None:
attention_mask = self.get_masks(
input_ids,
device=input_ids.device
)
attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@ -955,15 +971,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
use_gmasks.append(use_gmask)
position_ids = self.get_position_ids(
input_ids,
mask_positions=mask_positions,
device=input_ids.device,
use_gmasks=use_gmasks
input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
)
if self.pre_seq_len is not None and attention_mask is not None:
prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
attention_mask.device)
attention_mask.device
)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
@ -980,7 +994,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
attention_mask = attention_mask.to(hidden_states.device)
for i, layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_past = past_key_values[i]
@ -994,7 +1007,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
torch.tensor(i),
layer_past,
use_cache,
output_attentions
output_attentions,
)
else:
layer_ret = layer(
@ -1004,7 +1017,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
layer_id=torch.tensor(i),
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions
output_attentions=output_attentions,
)
hidden_states = layer_ret[0]
@ -1049,13 +1062,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
self.transformer = ChatGLMModel(config, empty_init=empty_init)
self.lm_head = init_method(
nn.Linear,
config.hidden_size,
config.vocab_size,
bias=False,
dtype=torch.half
)
self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)
self.config = config
@ -1087,32 +1094,29 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
attention_mask = model_kwargs["attention_mask"]
if attention_mask is not None and attention_mask.dtype == torch.bool:
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3)
[attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
)
new_attention_mask = attention_mask[:, :, -1:].clone()
new_attention_mask[..., -1] = False
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, new_attention_mask], dim=2
)
model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)
# update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id[:, 1, :] += 1
model_kwargs["position_ids"] = torch.cat(
[position_ids, new_position_id], dim=-1
)
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
return model_kwargs
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs
self,
input_ids: torch.LongTensor,
past: Optional[torch.Tensor] = None,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
batch_size, seq_length = input_ids.shape
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
@ -1137,11 +1141,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
if self.position_encoding_2d:
position_ids = torch.tensor(
[[mask_position, seq_length - context_length] for mask_position, context_length in
zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
[
[mask_position, seq_length - context_length]
for mask_position, context_length in zip(mask_positions, context_lengths)
],
dtype=torch.long,
device=input_ids.device,
).unsqueeze(-1)
else:
position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
device=input_ids.device).unsqueeze(-1)
position_ids = torch.tensor(
[mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
).unsqueeze(-1)
if past is None:
past = past_key_values
@ -1149,44 +1159,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
"input_ids": last_token,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
else:
if attention_mask is not None and attention_mask.dtype != torch.bool:
logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
attention_mask = None
if attention_mask is None:
attention_mask = self.get_masks(
input_ids,
device=input_ids.device
)
attention_mask = self.get_masks(input_ids, device=input_ids.device)
if position_ids is None:
position_ids = self.get_position_ids(
input_ids,
device=input_ids.device,
mask_positions=mask_positions,
use_gmasks=use_gmasks
input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
)
return {
"input_ids": input_ids,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@ -1235,7 +1239,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@staticmethod
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@ -1268,15 +1272,33 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
return response
@torch.no_grad()
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
def chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 2048,
num_beams=1,
do_sample=True,
top_p=0.7,
temperature=0.95,
logits_processor=None,
**kwargs,
):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if not history:
prompt = query
else:
@ -1287,22 +1309,38 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
outputs = self.generate(**inputs, **gen_kwargs)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
history = history + [(query, response)]
return response, history
@torch.no_grad()
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
def stream_chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 2048,
do_sample=True,
top_p=0.7,
temperature=0.95,
logits_processor=None,
**kwargs,
):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
gen_kwargs = {
"max_length": max_length,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
if not history:
prompt = query
else:
@ -1313,7 +1351,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
for outputs in self.stream_generate(**inputs, **gen_kwargs):
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs)
response = self.process_response(response)
new_history = history + [(query, response)]
@ -1321,13 +1359,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
@torch.no_grad()
def stream_generate(
self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
self,
input_ids,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
**kwargs,
):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

View File

@ -16,9 +16,9 @@ except ImportError:
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
def _prepare_logits_processor(top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None) -> LogitsProcessorList:
def _prepare_logits_processor(
top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
if temperature is not None and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
@ -37,18 +37,20 @@ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
return unfinished_sequences.max() == 0
def _sample(model: Actor,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def _sample(
model: Actor,
input_ids: torch.Tensor,
max_length: int,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
if input_ids.size(1) >= max_length:
return input_ids
@ -56,11 +58,12 @@ def _sample(model: Actor,
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
for _ in range(input_ids.size(1), max_length):
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \
if prepare_inputs_fn is not None else {'input_ids': input_ids}
model_inputs = (
prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
)
outputs = model(**model_inputs)
next_token_logits = outputs['logits'][:, -1, :]
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
@ -90,20 +93,22 @@ def _sample(model: Actor,
@torch.no_grad()
def generate(model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs) -> torch.Tensor:
def generate(
model: Actor,
input_ids: torch.Tensor,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
**model_kwargs,
) -> torch.Tensor:
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
Args:
@ -121,26 +126,28 @@ def generate(model: Actor,
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
if is_greedy_gen_mode:
# run greedy search
raise NotImplementedError
elif is_sample_gen_mode:
# run sample
return _sample(model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs)
return _sample(
model,
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
prepare_inputs_fn=prepare_inputs_fn,
update_model_kwargs_fn=update_model_kwargs_fn,
**model_kwargs,
)
elif is_beam_gen_mode:
raise NotImplementedError
else:
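For reference, a self-contained sketch of the single sampling step that _sample performs with the warper stack assembled by _prepare_logits_processor above; the vocabulary size and hyperparameters are placeholders.

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

# Same warper stack _prepare_logits_processor builds: temperature, then top-k, then top-p.
processors = LogitsProcessorList(
    [TemperatureLogitsWarper(0.7), TopKLogitsWarper(50), TopPLogitsWarper(0.9)]
)

input_ids = torch.tensor([[1, 2, 3]])       # dummy prompt tokens
next_token_logits = torch.randn(1, 32000)   # dummy vocab-sized logits for the last position

# One sampling step, mirroring the body of _sample: warp logits, softmax, draw a token.
warped = processors(input_ids, next_token_logits)
probs = torch.softmax(warped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)  # (1, 1)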

View File

@ -2,4 +2,4 @@ from .gpt_actor import GPTActor
from .gpt_critic import GPTCritic
from .gpt_rm import GPTRM
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM']
__all__ = ["GPTActor", "GPTCritic", "GPTRM"]

View File

@ -18,13 +18,15 @@ class GPTActor(Actor):
lora_train_bias (str): Bias training strategy for the LoRa layer.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = GPT2LMHeadModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -18,12 +18,14 @@ class GPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:

View File

@ -18,11 +18,13 @@ class GPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[GPT2Config] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = GPT2Model.from_pretrained(pretrained)
elif config is not None:

View File

@ -2,4 +2,4 @@ from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
__all__ = ["LlamaActor", "LlamaCritic", "LlamaRM"]

View File

@ -1,7 +1,6 @@
from typing import Optional
import torch
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaForCausalLM
from ..base import Actor
@ -18,13 +17,14 @@ class LlamaActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
elif config is not None:

View File

@ -17,13 +17,14 @@ class LlamaCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -1,7 +1,7 @@
from typing import Optional
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
from transformers import LlamaConfig, LlamaModel
from ..base import RewardModel
@ -17,12 +17,13 @@ class LlamaRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[LlamaConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -8,8 +8,7 @@ import torch.nn.functional as F
class LoraLinear(lora.LoRALayer, nn.Module):
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
"""
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
def __init__(
self,
@ -17,16 +16,14 @@ class LoraLinear(lora.LoRALayer, nn.Module):
bias: Optional[nn.Parameter],
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
):
nn.Module.__init__(self)
lora.LoRALayer.__init__(self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights)
lora.LoRALayer.__init__(
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
)
self.weight = weight
self.bias = bias
@ -47,13 +44,12 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.weight.data = self.weight.data.T
def reset_parameters(self):
if hasattr(self, 'lora_A'):
if hasattr(self, "lora_A"):
# Initialize A with the default values for nn.Linear and set B to zero.
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool = True):
def T(w):
return w.T if self.fan_in_fan_out else w
@ -71,7 +67,6 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.merged = False
def eval(self):
def T(w):
return w.T if self.fan_in_fan_out else w
@ -80,12 +75,11 @@ class LoraLinear(lora.LoRALayer, nn.Module):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
delattr(self, 'lora_A')
delattr(self, 'lora_B')
delattr(self, "lora_A")
delattr(self, "lora_B")
self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
return w.T if self.fan_in_fan_out else w
@ -99,7 +93,9 @@ class LoraLinear(lora.LoRALayer, nn.Module):
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
assert (
lora_rank <= linear.in_features
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
return lora_linear
@ -112,7 +108,7 @@ def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
_convert_to_lora_recursively(child, lora_rank)
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = 'none') -> nn.Module:
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
"""Convert a torch.nn.Module to a LoRA module.
Args:
@ -140,7 +136,7 @@ class LoRAModule(nn.Module):
Defaults to 'none'.
"""
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
super().__init__()
self.lora_rank = lora_rank
self.lora_train_bias = lora_train_bias
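A standalone sketch of the LoRA arithmetic that LoraLinear merges in eval() and unmerges in train() above; shapes, rank, and scaling are illustrative and dropout is omitted.

import torch

in_features, out_features, r, lora_alpha = 64, 32, 8, 1
weight = torch.randn(out_features, in_features)   # frozen base weight
lora_A = torch.randn(r, in_features) * 0.01       # trained low-rank factors
lora_B = torch.randn(out_features, r) * 0.01
scaling = lora_alpha / r

x = torch.randn(4, in_features)

# Unmerged (training-time) path: base output plus the low-rank correction.
y_unmerged = x @ weight.T + (x @ lora_A.T @ lora_B.T) * scaling

# Merged (eval-time) path, as in LoraLinear.eval(): fold B @ A into the weight once.
merged_weight = weight + (lora_B @ lora_A) * scaling
y_merged = x @ merged_weight.T

print(torch.allclose(y_unmerged, y_merged, atol=1e-5))  # True: the two paths are equivalent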

View File

@ -31,11 +31,13 @@ class PolicyLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps
def forward(self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
log_probs: torch.Tensor,
old_log_probs: torch.Tensor,
advantages: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
@ -55,14 +57,16 @@ class ValueLoss(nn.Module):
super().__init__()
self.clip_eps = clip_eps
def forward(self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self,
values: torch.Tensor,
old_values: torch.Tensor,
reward: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - reward)**2
surr2 = (values - reward)**2
surr1 = (values_clipped - reward) ** 2
surr2 = (values - reward) ** 2
loss = torch.max(surr1, surr2)
loss = loss.mean()
return 0.5 * loss
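A small numeric sketch of the two clipped objectives above, with made-up tensors; the real losses additionally apply an action-mask-aware mean, and the policy loss takes the usual negative minimum of the two surrogates.

import torch

clip_eps = 0.2
log_probs     = torch.tensor([-1.0, -0.5, -2.0])
old_log_probs = torch.tensor([-1.2, -0.4, -1.0])
advantages    = torch.tensor([ 1.0, -1.0,  0.5])

# PolicyLoss: PPO clipped surrogate objective.
ratio = (log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
policy_loss = -torch.min(surr1, surr2).mean()

# ValueLoss: clip the value update around the old value estimate, keep the worse of the two errors.
values, old_values, reward = torch.tensor([0.9]), torch.tensor([1.0]), torch.tensor([1.5])
values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
value_loss = 0.5 * torch.max((values_clipped - reward) ** 2, (values - reward) ** 2).mean()

print(policy_loss.item(), value_loss.item())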

View File

@ -2,4 +2,4 @@ from .opt_actor import OPTActor
from .opt_critic import OPTCritic
from .opt_rm import OPTRM
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM']
__all__ = ["OPTActor", "OPTCritic", "OPTRM"]

View File

@ -18,12 +18,14 @@ class OPTActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
checkpoint: bool = False,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = OPTForCausalLM.from_pretrained(pretrained)
elif config is not None:

View File

@ -18,12 +18,14 @@ class OPTCritic(Critic):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none',
**kwargs) -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
**kwargs,
) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -17,11 +17,13 @@ class OPTRM(RewardModel):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = 'none') -> None:
def __init__(
self,
pretrained: Optional[str] = None,
config: Optional[OPTConfig] = None,
lora_rank: int = 0,
lora_train_bias: str = "none",
) -> None:
if pretrained is not None:
model = OPTModel.from_pretrained(pretrained)
elif config is not None:

View File

@ -4,9 +4,9 @@ import torch
import torch.nn.functional as F
def _compute_approx_kl(log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def _compute_approx_kl(
log_probs: torch.Tensor, log_probs_base: torch.Tensor, action_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute the approximate KL divergence between two distributions.
Schulman blog: http://joschu.net/blog/kl-approx.html
@ -26,11 +26,13 @@ def _compute_approx_kl(log_probs: torch.Tensor,
return approx_kl
def compute_reward(r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def compute_reward(
r: Union[torch.Tensor, float],
kl_coef: float,
log_probs: torch.Tensor,
log_probs_base: torch.Tensor,
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if kl_coef <= 0.0:
return r
kl = _compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
@ -55,7 +57,7 @@ def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num
Returns:
torch.Tensor: Action log probs.
"""
logits = output['logits']
logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
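A sketch of KL-penalized reward shaping in the spirit of compute_reward above, using the simple log-ratio (k1) estimator from the linked Schulman post; the exact estimator and masking inside _compute_approx_kl are not visible in this hunk, so treat the details as assumptions.

import torch

kl_coef = 0.1
r = torch.tensor([1.0, 0.5])                                   # scalar reward per sequence
log_probs      = torch.tensor([[-1.0, -2.0], [-0.5, -1.5]])    # actor log-probs per action token
log_probs_base = torch.tensor([[-1.1, -1.8], [-0.7, -1.4]])    # initial-model log-probs

# k1 estimator: mean log-ratio over the action tokens (action-mask handling omitted for brevity).
approx_kl = (log_probs - log_probs_base).mean(dim=-1)
shaped_reward = r - kl_coef * approx_kl

print(shaped_reward)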

View File

@ -2,6 +2,6 @@ from .llama_gptq import load_quant as llama_load_quant
from .utils import low_resource_init
__all__ = [
'llama_load_quant',
'low_resource_init',
"llama_load_quant",
"low_resource_init",
]

View File

@ -1,5 +1,5 @@
from .loader import load_quant
__all__ = [
'load_quant',
"load_quant",
]

View File

@ -11,14 +11,15 @@ def load_quant(model: nn.Module, checkpoint: str, wbits: int, groupsize: int):
# ignore lm head
layers = find_layers(model)
for name in ['lm_head']:
for name in ["lm_head"]:
if name in layers:
del layers[name]
make_quant(model, layers, wbits, groupsize)
if checkpoint.endswith('.safetensors'):
if checkpoint.endswith(".safetensors"):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))

View File

@ -1,13 +1,12 @@
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
import torch
import torch.nn as nn
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1))
return res
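To illustrate what find_layers above produces, a runnable sketch with a toy module; the result is a flat dict keyed by dotted module names, which make_quant later uses to swap in QuantLinear layers. The recursion is re-stated locally so the snippet runs on its own (isinstance is used here for brevity).

import torch.nn as nn

def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    # Same recursion as the helper above: collect target layers keyed by dotted name.
    if isinstance(module, layers):
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        res.update(find_layers(child, layers=layers, name=f"{name}.{child_name}" if name else child_name))
    return res

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Sequential(nn.Linear(16, 4)))
print(list(find_layers(model).keys()))  # ['0', '2.0']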

View File

@ -13,14 +13,13 @@ def quantize(x, scale, zero, maxq):
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=0.8):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
@ -68,7 +67,7 @@ class Quantizer(nn.Module):
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
@ -123,13 +122,12 @@ class Quantizer(nn.Module):
try:
import quant_cuda
except:
print('CUDA extension not installed.')
print("CUDA extension not installed.")
# Assumes layer is perfectly divisible into 256 * 256 blocks
class QuantLinear(nn.Module):
def __init__(self, bits, groupsize, infeatures, outfeatures):
super().__init__()
if bits not in [2, 3, 4, 8]:
@ -142,11 +140,11 @@ class QuantLinear(nn.Module):
groupsize = groupsize if groupsize != -1 else infeatures
self.groupsize = groupsize
self.register_buffer(
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
dtype=torch.int))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer('bias', torch.zeros(outfeatures))
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
"qzeros", torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int)
)
self.register_buffer("scales", torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
self.register_buffer("bias", torch.zeros(outfeatures))
self.register_buffer("qweight", torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
self._initialized_quant_state = False
def pack(self, linear, scales, zeros):
@ -161,8 +159,10 @@ class QuantLinear(nn.Module):
for idx in range(self.infeatures):
g_idx = idx // self.groupsize
intweight.append(
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
None])
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[
:, None
]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
@ -271,13 +271,13 @@ class QuantLinear(nn.Module):
return y.reshape(outshape)
def make_quant(module, names, bits, groupsize, name=''):
def make_quant(module, names, bits, groupsize, name=""):
if isinstance(module, QuantLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
name1 = name + "." + attr if name != "" else attr
if name1 in names:
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
for name1, child in module.named_children():
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
make_quant(child, names, bits, groupsize, name + "." + name1 if name != "" else name1)

View File

@ -9,8 +9,7 @@ def _noop(*args, **kwargs):
@contextmanager
def low_resource_init():
"""This context manager disables weight initialization and sets the default float dtype to half.
"""
"""This context manager disables weight initialization and sets the default float dtype to half."""
old_kaiming_uniform_ = torch.nn.init.kaiming_uniform_
old_uniform_ = torch.nn.init.uniform_
old_normal_ = torch.nn.init.normal_

View File

@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class TrainerCallback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:
@ -40,7 +40,6 @@ class TrainerCallback(ABC):
class MakerCallback(ABC):
def on_loop_start(self) -> None:
pass

View File

@ -30,10 +30,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@ -42,13 +41,13 @@ class Timer:
self.duration += time() - self.start_time
def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
class ExperienceMakerPerformanceEvaluator(MakerCallback):
def __init__(self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int,
reward_model_num_params: int) -> None:
def __init__(
self, actor_num_params: int, critic_num_params: int, initial_model_num_params: int, reward_model_num_params: int
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@ -63,7 +62,7 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
self.make_experience_flop: int = 0
print_rank_0(
f'ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"ExperienceMaker actor: {actor_num_params/1024**3:.2f}B, critic: {critic_num_params/1024**3:.2f}B, initial model: {initial_model_num_params/1024**3:.2f}B, reward model: {reward_model_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_make_experience_start(self) -> None:
@ -110,27 +109,29 @@ class ExperienceMakerPerformanceEvaluator(MakerCallback):
avg_throughput = self.total_samples * self.world_size / (avg_overall_duration + 1e-12)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_time_per_sample = (avg_overall_duration + 1e-12) / (self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / \
(self.total_samples * self.world_size)
avg_make_experience_time_per_sample = (avg_make_experience_duration + 1e-12) / (
self.total_samples * self.world_size
)
avg_send_time_per_sample = (avg_send_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Making Experience Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n'
+ f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Making Experience Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_make_experience_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (make experience): {avg_make_experience_time_per_sample:.3f} s, {avg_make_experience_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (send): {avg_send_time_per_sample:.3f} s, {avg_send_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)
class TrainerPerformanceEvaluator(TrainerCallback):
def __init__(self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_first_episodes: int = 1,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@ -146,7 +147,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
self.learn_flop: int = 0
print_rank_0(
f'Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}'
f"Trainer actor: {self.actor_num_params/1024**3:.2f}B, critic: {self.critic_num_params/1024**3:.2f}B, world size: {self.world_size}"
)
def on_episode_start(self, episodes: int) -> None:
@ -191,7 +192,7 @@ class TrainerPerformanceEvaluator(TrainerCallback):
def on_fit_end(self) -> None:
if self.total_samples == 0:
print_rank_0('No samples are collected, skip trainer performance evaluation')
print_rank_0("No samples are collected, skip trainer performance evaluation")
return
avg_train_duration = all_reduce_mean(self.batch_timer.duration, self.world_size)
avg_update_duration = all_reduce_mean(self.update_timer.duration, self.world_size)
@ -204,9 +205,10 @@ class TrainerPerformanceEvaluator(TrainerCallback):
avg_update_time_per_sample = (avg_update_duration + 1e-12) / (self.total_samples * self.world_size)
print_rank_0(
'Learning Performance Summary:\n' + f'Throughput: {avg_throughput:.3f} samples/sec\n'
+ f'TFLOPS per GPU: {avg_learn_tflops:.3f}\n' + f'Sample time (overall): {avg_time_per_sample:.3f} s\n'
+ f'Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n'
+ f'Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n'
"Learning Performance Summary:\n"
+ f"Throughput: {avg_throughput:.3f} samples/sec\n"
+ f"TFLOPS per GPU: {avg_learn_tflops:.3f}\n"
+ f"Sample time (overall): {avg_time_per_sample:.3f} s\n"
+ f"Sample time (train): {avg_train_time_per_sample:.3f} s, {avg_train_time_per_sample/avg_time_per_sample*100:.2f}%\n"
+ f"Sample time (update): {avg_update_time_per_sample:.3f} s, {avg_update_time_per_sample/avg_time_per_sample*100:.2f}%\n"
)

View File

@ -1,20 +1,15 @@
import asyncio
import copy
import random
from threading import Lock
from typing import Any, List
from typing import List
import ray
import torch
from coati.experience_buffer import ExperienceBuffer
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker.base import Experience
# from torch.multiprocessing import Queue
from ray.util.queue import Queue
class DetachedReplayBuffer:
'''
"""
Detached replay buffer. Share Experience across workers on the same node.
Therefore, a trainer node is expected to have only one instance.
It is ExperienceMakerHolder's duty to call append(exp) method, remotely.
@ -24,7 +19,7 @@ class DetachedReplayBuffer:
tp_world_size: Number of workers in the same tp group
limit: Limit of number of experience sample BATCHs. A number <= 0 means unlimited. Defaults to 0.
cpu_offload: Whether to offload experience to cpu when sampling. Defaults to True.
'''
"""
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
self.sample_batch_size = sample_batch_size
@ -34,23 +29,23 @@ class DetachedReplayBuffer:
@torch.no_grad()
def append(self, experience: Experience) -> None:
'''
"""
Expected to be called remotely.
'''
"""
items = split_experience_batch(experience)
self.extend(items)
@torch.no_grad()
def extend(self, items: List[BufferItem]) -> None:
'''
"""
Expected to be called remotely.
'''
"""
self.batch_collector.extend(items)
while len(self.batch_collector) >= self.sample_batch_size:
items = self.batch_collector[:self.sample_batch_size]
items = self.batch_collector[: self.sample_batch_size]
experience = make_experience_batch(items)
self.items.put(experience, block=True)
self.batch_collector = self.batch_collector[self.sample_batch_size:]
self.batch_collector = self.batch_collector[self.sample_batch_size :]
def clear(self) -> None:
# self.items.close()
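A plain-Python sketch of the batching behaviour of extend() above: items accumulate in batch_collector and are flushed to the queue in fixed-size batches; the real buffer uses ray.util.queue.Queue and make_experience_batch, both swapped out here for standard-library stand-ins.

from queue import Queue  # stand-in for ray.util.queue.Queue used above

class BatchingBuffer:
    """Mirror of DetachedReplayBuffer.extend(): collect items, emit fixed-size batches."""

    def __init__(self, sample_batch_size: int) -> None:
        self.sample_batch_size = sample_batch_size
        self.batch_collector = []
        self.items = Queue()

    def extend(self, new_items) -> None:
        self.batch_collector.extend(new_items)
        while len(self.batch_collector) >= self.sample_batch_size:
            batch = self.batch_collector[: self.sample_batch_size]  # make_experience_batch(...) in the real buffer
            self.items.put(batch, block=True)
            self.batch_collector = self.batch_collector[self.sample_batch_size :]

buf = BatchingBuffer(sample_batch_size=4)
buf.extend(list(range(6)))                 # emits one batch [0, 1, 2, 3], keeps [4, 5]
print(buf.items.qsize(), buf.batch_collector)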

View File

@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List
import ray
import torch
@ -15,7 +15,7 @@ from .utils import is_rank_0
class DetachedTrainer(ABC):
'''
"""
Base class for detached rlhf trainers.
'detach' means that the experience maker is detached compared to a normal Trainer.
Please set name attribute during init:
@ -28,15 +28,17 @@ class DetachedTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""
def __init__(self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False) -> None:
def __init__(
self,
experience_maker_holder_name_list: List[str],
train_batch_size: int = 8,
buffer_limit: int = 0,
dataloader_pin_memory: bool = True,
callbacks: List[TrainerCallback] = [],
debug: bool = False,
) -> None:
super().__init__()
self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
self.dataloader_pin_memory = dataloader_pin_memory
@ -67,18 +69,16 @@ class DetachedTrainer(ABC):
def _learn(self, update_steps: int, train_epochs: int) -> None:
data = []
# warmup
pbar = tqdm(range(update_steps), desc=f'Train epoch [1/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(0)
self._learn_epoch(pbar, data)
self._on_epoch_end(0)
# item is already a batch
dataloader = DataLoader(data,
batch_size=1,
shuffle=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=lambda x: x[0])
dataloader = DataLoader(
data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
)
for epoch in range(1, train_epochs):
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{train_epochs}]', disable=not is_rank_0())
pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
self._on_epoch_start(epoch)
self._learn_epoch(pbar, data)
self._on_epoch_end(epoch)
@ -104,7 +104,7 @@ class DetachedTrainer(ABC):
def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
self._on_fit_start()
for i in tqdm(range(total_steps // update_steps), desc='Trainer', disable=not is_rank_0()):
for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
self._on_episode_start(i)
self._learn(update_steps, train_epochs)
self._on_update_start()

View File

@ -1,12 +1,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
import ray
import torch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.experience_maker import Experience
from coati.models.base import Actor, Critic
from coati.models.loss import PolicyLoss, ValueLoss
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy, Strategy
from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
from torch.optim import Adam
from colossalai.nn.optimizer import HybridAdam
@ -14,27 +13,14 @@ from colossalai.nn.optimizer import HybridAdam
from .callbacks import TrainerCallback, TrainerPerformanceEvaluator
from .detached_trainer_base import DetachedTrainer
from .lora_constructor import LoRAConstructor
from .utils import (
get_actor_from_args,
get_critic_from_args,
get_model_numel,
get_rank,
get_strategy_from_args,
is_rank_0,
set_dist_env,
state_dict_to,
from .utils import get_model_numel, get_rank, set_dist_env, state_dict_to
)
@ray.remote(
concurrency_groups={"buffer_length": 1, "buffer_append": 1, "buffer_sample": 1, "model_io": 1, "compute": 1}
)
@ray.remote(concurrency_groups={
"buffer_length": 1,
"buffer_append": 1,
"buffer_sample": 1,
"model_io": 1,
"compute": 1
})
class DetachedPPOTrainer(DetachedTrainer):
'''
"""
Detached Trainer for PPO algorithm
Args:
strategy (Strategy): the strategy to use for training
@ -52,7 +38,7 @@ class DetachedPPOTrainer(DetachedTrainer):
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
'''
"""
def __init__(
self,
@ -92,21 +78,24 @@ class DetachedPPOTrainer(DetachedTrainer):
self.actor_optim = Adam(self.actor.parameters(), lr=1e-7)
self.critic_optim = Adam(self.critic.parameters(), lr=1e-7)
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = \
self.strategy.prepare((self.actor, self.actor_optim), (self.critic, self.critic_optim))
(self.actor, self.actor_optim), (self.critic, self.critic_optim) = self.strategy.prepare(
(self.actor, self.actor_optim), (self.critic, self.critic_optim)
)
# configure trainer
self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip)
super().__init__(experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug)
super().__init__(
experience_maker_holder_name_list,
train_batch_size=train_batch_size,
buffer_limit=buffer_limit,
dataloader_pin_memory=dataloader_pin_memory,
callbacks=callbacks,
debug=debug,
)
if self._debug:
print(f'[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}')
print(f"[trainer{get_rank()}] will send state dict to {experience_maker_holder_name_list}")
self._update_lora_weights = update_lora_weights
@ -115,7 +104,7 @@ class DetachedPPOTrainer(DetachedTrainer):
def _update_remote_makers(self, fully_update: bool = False, **config):
# TODO: balance duties
if not fully_update:
config['requires_grad_only'] = True
config["requires_grad_only"] = True
self.update_target_holder_list()
# mark start, ensure order
tasks = []
@ -131,7 +120,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_actor_state_dict=state_dict_shard,
new_actor_lora_config_dict=self._get_model_lora_config_dict(self.actor),
fully_update=fully_update))
fully_update=fully_update,
)
)
# sending loop
for state_dict_shard in self._get_model_state_dict_shard(self.critic, fully_update=fully_update, **config):
for target_holder in self.target_holder_list:
@ -139,7 +130,9 @@ class DetachedPPOTrainer(DetachedTrainer):
target_holder.update_experience_maker.remote(
new_critic_state_dict=state_dict_shard,
new_critic_lora_config_dict=self._get_model_lora_config_dict(self.critic),
fully_update=fully_update))
fully_update=fully_update,
)
)
ray.get(tasks)
# mark end
for target_holder in self.target_holder_list:
@ -152,26 +145,24 @@ class DetachedPPOTrainer(DetachedTrainer):
num_actions = experience.action_mask.size(1)
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
self.strategy.optimizer_step(self.actor_optim)
self.actor_optim.zero_grad()
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
return {"actor_loss": actor_loss.item(), "critic_loss": critic_loss.item()}
def strategy_save_actor(self, path: str, only_rank0: bool = False) -> None:
self.strategy.save_model(self.actor, path, only_rank0)

View File

@ -1,53 +1,49 @@
import os
import time
import tracemalloc
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray
import torch
import torch.nn as nn
from coati.experience_buffer.utils import BufferItem, make_experience_batch, split_experience_batch
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.callbacks import Callback
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, get_world_size, is_rank_0, set_dist_env, state_dict_to
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
"""
Args:
detached_trainer_name_list: str list to get ray actor handles
strategy:
kl_coef: the coefficient of kl divergence loss
sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
'''
"""
def __init__(
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
self,
detached_trainer_name_list: List[str],
strategy_fn: Callable[[], Strategy],
# a function returns (actor, critic, reward_model, initial_model)
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs):
model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
env_info: Dict[str, str] = None,
sync_models_from_trainers: bool = False,
buffer_cpu_offload: bool = True,
kl_coef: float = 0.1,
callbacks: List[MakerCallback] = [],
eval_performance: bool = False,
debug: bool = False,
update_lora_weights: bool = False,
**generate_kwargs,
):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
@ -66,8 +62,9 @@ class ExperienceMakerHolder:
critic_numel = get_model_numel(critic)
initial_model_numel = get_model_numel(initial_model)
reward_model_numel = get_model_numel(reward_model)
evaluator = ExperienceMakerPerformanceEvaluator(actor_numel, critic_numel, initial_model_numel,
reward_model_numel)
evaluator = ExperienceMakerPerformanceEvaluator(
actor_numel, critic_numel, initial_model_numel, reward_model_numel
)
callbacks = callbacks + [evaluator]
actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
@ -89,9 +86,9 @@ class ExperienceMakerHolder:
self._target_idx = 0
if self._debug:
print(f'[maker{get_rank()}] will send items to {self._detached_trainer_name_list}')
print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
if not self._is_fully_initialized:
print(f'[maker{get_rank()}] Waiting for INIT')
print(f"[maker{get_rank()}] Waiting for INIT")
def _get_ready(self):
while not self._fully_initialized():
@ -136,7 +133,7 @@ class ExperienceMakerHolder:
self._on_make_experience_end(experience)
self._on_send_start()
if self.buffer_cpu_offload:
experience.to_device('cpu')
experience.to_device("cpu")
self._send_items(experience)
self._on_send_end()
self._on_batch_end()
@ -155,7 +152,7 @@ class ExperienceMakerHolder:
if num_steps > 0:
# ignore num epochs
it = iter(dataloader)
for _ in tqdm(range(num_steps), desc='ExperienceMaker', disable=not is_rank_0()):
for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
try:
batch = next(it)
except StopIteration:
@ -163,7 +160,7 @@ class ExperienceMakerHolder:
batch = next(it)
self._inference_step(batch)
else:
with tqdm(total=num_epochs * len(dataloader), desc='ExperienceMaker', disable=not is_rank_0()) as pbar:
with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
for _ in range(num_epochs):
for batch in dataloader:
self._inference_step(batch)
@ -171,22 +168,24 @@ class ExperienceMakerHolder:
self._on_loop_end()
@ray.method(concurrency_group="model_io")
def update_experience_maker(self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None):
'''
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
def update_experience_maker(
self,
new_actor_state_dict: Dict[str, Any] = None,
new_actor_lora_config_dict: Dict[str, Any] = None,
new_critic_state_dict: Dict[str, Any] = None,
new_critic_lora_config_dict: Dict[str, Any] = None,
fully_update: bool = False,
chunk_start: bool = None,
chunk_end: bool = None,
):
"""
called by trainer
chunk_start: Set True at the first call. Before sending state_dict calls
chunk_end: Set True at the last call. After sending state_dict calls.
fully_update: Set True if you want to sync models when initializing
TODO: load_state_dict integrate with model-sharding strategy
'''
TODO: load_state_dict integrate with model-sharding strategy
"""
_watch_memory = self._debug
if chunk_start:
if self._debug:
@ -202,18 +201,22 @@ class ExperienceMakerHolder:
else:
new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
new_actor_state_dict, new_actor_lora_config_dict)
new_actor_state_dict, new_actor_lora_config_dict
)
self.actor_lora_constructor.load_state_dict_increase(
self.experience_maker.actor.model, state_dict_increase)
self.experience_maker.actor.model, state_dict_increase
)
if new_critic_state_dict is not None:
if not self._update_lora_weights or fully_update:
self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
else:
new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
new_critic_state_dict, new_critic_lora_config_dict)
new_critic_state_dict, new_critic_lora_config_dict
)
self.critic_lora_constructor.load_state_dict_increase(
self.experience_maker.critic, state_dict_increase)
self.experience_maker.critic, state_dict_increase
)
# the lock must be released after both actor and critic being updated
if chunk_end:
@ -262,10 +265,10 @@ def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> None:
origin_model = actor.model
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
return new_kwargs

View File

@ -1,11 +1,9 @@
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict
import torch
import torch.nn as nn
from coati.models.lora import LoraLinear
from loralib.layers import LoRALayer
@dataclass
@ -17,7 +15,7 @@ class LoRAConfig:
class LoRAConstructor:
'''
"""
Tools for reconstructing a model from a remote LoRA model.
(Transferring only LoRA data costs much less!)
Usage:
@ -36,7 +34,7 @@ class LoRAConstructor:
Step 5 (Receiver):
load_state_dict_increase()
'''
"""
def __init__(self):
self.lora_config_dict = None
@ -45,10 +43,10 @@ class LoRAConstructor:
self.lora_config_dict = lora_config_dict
def reconstruct_increase(self, state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]):
'''
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
'''
"""
xxx.lora_A, xxx.lora_B -->> xxx.weight
Warning: the xxx.weight here is the increment actually.
"""
if lora_config_dict is not None:
self.register_lora_config(lora_config_dict)
@ -56,24 +54,25 @@ class LoRAConstructor:
config_iter = iter(self.lora_config_dict.items())
lora_A, lora_B, layer_prefix = None, None, None
for k, v in state_dict_lora.items():
if k.rpartition('.')[-1] == 'lora_A':
if k.rpartition(".")[-1] == "lora_A":
lora_A = v
layer_prefix = k.rpartition('.')[0]
elif k.rpartition('.')[-1] == 'lora_B':
assert layer_prefix == k.rpartition('.')[0], "unmatched (lora_A, lora_B) pair"
layer_prefix = k.rpartition(".")[0]
elif k.rpartition(".")[-1] == "lora_B":
assert layer_prefix == k.rpartition(".")[0], "unmatched (lora_A, lora_B) pair"
layer_prefix_2, config = next(config_iter)
assert layer_prefix_2 == layer_prefix, "unmatched (state_dict, config_dict) pair"
lora_B = v
weight_data_increase = self._compute(lora_A, lora_B, config)
state_dict_increase[layer_prefix + '.weight'] = weight_data_increase
state_dict_increase[layer_prefix + ".weight"] = weight_data_increase
lora_A, lora_B, layer_prefix = None, None, None
else:
raise ValueError('unexpected key')
raise ValueError("unexpected key")
return state_dict_increase
def _compute(self, lora_A, lora_B, config=LoRAConfig()):
def T(w):
return w.T if config.fan_in_fan_out else w
if config.r > 0:
scaling = config.lora_alpha / config.r
weight_data_increase = T(lora_B @ lora_A) * scaling
@ -81,21 +80,21 @@ class LoRAConstructor:
return 0
def load_state_dict_increase(self, model: nn.Module, state_dict_increase: Dict[str, Any]):
'''
"""
The final reconstruction step
'''
"""
# naive approach
model.load_state_dict({k: v + model.state_dict()[k] for k, v in state_dict_increase.items()}, strict=False)
@staticmethod
def filter_state_dict_lora(state_dict: Dict[str, Any], keep_non_lora=False):
'''
"""
if keep_non_lora, also return non_lora state_dict
'''
"""
state_dict_lora = OrderedDict()
state_dict_non_lora = OrderedDict()
for k, v in state_dict.items():
if 'lora_A' in k or 'lora_B' in k:
if "lora_A" in k or "lora_B" in k:
state_dict_lora[k] = v
elif keep_non_lora:
state_dict_non_lora[k] = v
@ -106,17 +105,19 @@ class LoRAConstructor:
@staticmethod
def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]:
'''
"""
extract LoraLinear model.
return OrderedDict(): name -> LoRAConfig
'''
"""
lora_config_dict = OrderedDict()
for name, child in model.named_modules():
if isinstance(child, LoraLinear):
lora_config_dict[name] = LoRAConfig(r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out)
lora_config_dict[name] = LoRAConfig(
r=child.r,
lora_alpha=child.lora_alpha,
lora_dropout=child.lora_dropout,
fan_in_fan_out=child.fan_in_fan_out,
)
return lora_config_dict
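A standalone sketch of the reconstruction arithmetic above: only the (lora_A, lora_B) factors are transferred, and the receiver rebuilds the dense weight increment and adds it onto the current weight; shapes and values are illustrative.

import torch

r, lora_alpha = 8, 1
in_features, out_features = 64, 32

# What the sender transmits for one layer: just the low-rank factors.
lora_A = torch.randn(r, in_features) * 0.01
lora_B = torch.randn(out_features, r) * 0.01

# What the receiver reconstructs (LoRAConstructor._compute): the dense increment.
scaling = lora_alpha / r
weight_increase = (lora_B @ lora_A) * scaling            # shape (out_features, in_features)

# Final step (load_state_dict_increase): add the increment onto the existing weight.
current_weight = torch.randn(out_features, in_features)
updated_weight = current_weight + weight_increase
print(updated_weight.shape, weight_increase.numel(), lora_A.numel() + lora_B.numel())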

View File

@ -1,6 +1,6 @@
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict
import torch
import torch.distributed as dist
@ -10,7 +10,7 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer
def is_rank_0() -> bool:
@ -26,13 +26,13 @@ def get_world_size() -> int:
def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'bloom':
elif model == "bloom":
actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'opt':
elif model == "opt":
actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
elif model == 'llama':
elif model == "llama":
actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
else:
raise ValueError(f'Unsupported actor model "{model}"')
@ -40,13 +40,13 @@ def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_ra
def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
if model == 'gpt2':
if model == "gpt2":
critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'bloom':
elif model == "bloom":
critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'opt':
elif model == "opt":
critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
elif model == 'llama':
elif model == "llama":
critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@ -54,13 +54,13 @@ def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_r
def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
if model == 'gpt2':
if model == "gpt2":
reward_model = GPTRM(pretrained=pretrained, config=config)
elif model == 'bloom':
elif model == "bloom":
reward_model = BLOOMRM(pretrained=pretrained, config=config)
elif model == 'opt':
elif model == "opt":
reward_model = OPTRM(pretrained=pretrained, config=config)
elif model == 'llama':
elif model == "llama":
reward_model = LlamaRM(pretrained=pretrained, config=config)
else:
raise ValueError(f'Unsupported reward model "{model}"')
@ -68,29 +68,29 @@ def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
def get_strategy_from_args(strategy: str):
if strategy == 'ddp':
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == 'colossalai_gemini':
strategy_ = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif strategy == 'colossalai_gemini_cpu':
strategy_ = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif strategy == 'colossalai_zero2_cpu':
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
return strategy_
def get_tokenizer_from_args(model: str, **kwargs):
if model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
elif model == 'bloom':
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
elif model == 'opt':
if model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
elif model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
elif model == "opt":
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
elif model == 'llama':
elif model == "llama":
pretrain_path = kwargs["pretrain"]
tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
else:
@ -101,11 +101,11 @@ def get_tokenizer_from_args(model: str, **kwargs):
def set_dist_env(env_info: Dict[str, str]):
os.environ["RANK"] = env_info['rank']
os.environ["LOCAL_RANK"] = env_info['local_rank']
os.environ["WORLD_SIZE"] = env_info['world_size']
os.environ['MASTER_PORT'] = env_info['master_port']
os.environ['MASTER_ADDR'] = env_info['master_addr']
os.environ["RANK"] = env_info["rank"]
os.environ["LOCAL_RANK"] = env_info["local_rank"]
os.environ["WORLD_SIZE"] = env_info["world_size"]
os.environ["MASTER_PORT"] = env_info["master_port"]
os.environ["MASTER_ADDR"] = env_info["master_addr"]
def get_model_numel(model: nn.Module) -> int:
@ -128,12 +128,12 @@ def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: i
return target_receivers
def state_dict_to(state_dict: Dict[str, Any],
dtype: torch.dtype = torch.float16,
device: torch.device = torch.device('cpu')):
'''
keep state_dict intact
'''
def state_dict_to(
state_dict: Dict[str, Any], dtype: torch.dtype = torch.float16, device: torch.device = torch.device("cpu")
):
"""
keep state_dict intact
"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k] = v.to(dtype=dtype, device=device)

View File

@ -3,8 +3,4 @@ from .ppo import PPOTrainer
from .rm import RewardModelTrainer
from .sft import SFTTrainer
__all__ = [
'SLTrainer', 'OnPolicyTrainer',
'RewardModelTrainer', 'SFTTrainer',
'PPOTrainer'
]
__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]

View File

@ -68,12 +68,14 @@ class OnPolicyTrainer(ABC):
callbacks (List[Callback], defaults to []): the callbacks to call during training process
"""
def __init__(self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = []) -> None:
def __init__(
self,
strategy: Strategy,
data_buffer: NaiveExperienceBuffer,
sample_buffer: bool,
dataloader_pin_memory: bool,
callbacks: List[Callback] = [],
) -> None:
super().__init__()
self.strategy = strategy
self.data_buffer = data_buffer

View File

@ -2,4 +2,4 @@ from .base import Callback
from .performance_evaluator import PerformanceEvaluator
from .save_checkpoint import SaveCheckpoint
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]

View File

@ -5,7 +5,7 @@ from coati.experience_maker import Experience
class Callback(ABC):
"""
Base callback class. It defines the interface for callbacks.
Base callback class. It defines the interface for callbacks.
"""
def on_fit_start(self) -> None:

View File

@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
def divide(x: float, y: float) -> float:
if y == 0:
return float('inf')
elif y == float('inf'):
return float('nan')
return float("inf")
elif y == float("inf"):
return float("nan")
return x / y
@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
class Timer:
def __init__(self) -> None:
self.start_time: Optional[float] = None
self.duration: float = 0.
self.duration: float = 0.0
def start(self) -> None:
self.start_time = time()
@ -52,7 +51,7 @@ class Timer:
self.start_time = None
def reset(self) -> None:
self.duration = 0.
self.duration = 0.0
class PerformanceEvaluator(Callback):
@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
ignore_episodes: The number of episodes to ignore when calculating the performance.
"""
def __init__(self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0) -> None:
def __init__(
self,
actor_num_params: int,
critic_num_params: int,
initial_model_num_params: int,
reward_model_num_params: int,
enable_grad_checkpoint: bool = False,
ignore_episodes: int = 0,
) -> None:
super().__init__()
self.world_size = get_world_size()
self.actor_num_params = actor_num_params
@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)
avg_make_experience_throughput = self.make_experience_num_samples * \
self.world_size / (avg_make_experience_duration + 1e-12)
avg_make_experience_throughput = (
self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
)
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)
@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)
print_rank_0(
f'Performance summary:\n'
+ f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n'
+ f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n'
+ f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n'
+ f'Overall time per sample: {overall_time_per_sample:.2f} s\n'
+ f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n'
+ f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%'
f"Performance summary:\n"
+ f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+ f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+ f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+ f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+ f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample/overall_time_per_sample*100:.2f}%\n"
+ f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample/overall_time_per_sample*100:.2f}%"
)
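
The throughput figures in that summary reduce to simple ratios: samples per second is local samples times world size divided by duration, and TFLOPS per GPU divides the accumulated FLOP count by 1e12 and by the same duration. A tiny worked example with hypothetical numbers, only to make the arithmetic concrete:

# Hypothetical values; the variable names mirror the fields used above.
world_size = 8
make_experience_num_samples = 512      # samples generated per rank
make_experience_flop = 3.2e15          # accumulated FLOPs for the episode
avg_make_experience_duration = 40.0    # seconds, already averaged across ranks

throughput = make_experience_num_samples * world_size / (avg_make_experience_duration + 1e-12)
tflops_per_gpu = make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
print(f"{throughput:.2f} samples/s, {tflops_per_gpu:.2f} TFLOPS per GPU")  # 102.40 samples/s, 80.00 TFLOPS per GPU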

View File

@ -36,34 +36,35 @@ class SaveCheckpoint(Callback):
"""
def __init__(self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None) -> None:
def __init__(
self,
path: str,
interval: int,
strategy: Strategy,
actor: nn.Module = None,
critic: nn.Module = None,
actor_optim: Optimizer = None,
critic_optim: Optimizer = None,
) -> None:
super().__init__()
self.path = os.path.join(path, 'checkpoint')
self.path = os.path.join(path, "checkpoint")
self.interval = interval
self.strategy = strategy
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}
def on_episode_end(self, episode: int) -> None:
if (episode + 1) % self.interval != 0:
return
base_path = os.path.join(self.path, f'episode_{episode}')
base_path = os.path.join(self.path, f"episode_{episode}")
if not os.path.exists(base_path):
os.makedirs(base_path)
for model in self.model_dict.keys():
# save model
if self.model_dict[model][0] is None:
# saving only optimizer states is meaningless, so it would be skipped
continue
model_path = os.path.join(base_path, f'{model}.pt')
model_path = os.path.join(base_path, f"{model}.pt")
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
# save optimizer
@ -71,5 +72,5 @@ class SaveCheckpoint(Callback):
continue
only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
rank = 0 if is_rank_0() else dist.get_rank()
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
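
The callback above fires only every `interval` episodes and lays checkpoints out as checkpoint/episode_{n}/{actor,critic}.pt plus per-rank optimizer files. A minimal sketch of the same episode gating and path layout; `maybe_checkpoint` is a hypothetical helper and the actual saving is stubbed out:

import os


def maybe_checkpoint(episode, interval, base_dir, models):
    # Fire only on every `interval`-th episode, mirroring the (episode + 1) % interval check above.
    if (episode + 1) % interval != 0:
        return []
    episode_dir = os.path.join(base_dir, "checkpoint", f"episode_{episode}")
    os.makedirs(episode_dir, exist_ok=True)
    # The real callback calls strategy.save_model / save_optimizer here; this sketch only returns the paths.
    return [os.path.join(episode_dir, f"{name}.pt") for name in models]


print(maybe_checkpoint(episode=9, interval=5, base_dir="/tmp/ppo_run", models=["actor", "critic"]))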

View File

@ -8,7 +8,7 @@ from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
from coati.models.utils import calc_action_log_probs
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DistributedSampler
from tqdm import tqdm
from colossalai.utils import get_current_device
@ -24,11 +24,11 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto
hf_model = get_base_model(unwrapper_model)
new_kwargs = {**generate_kwargs}
# use huggingface models method directly
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'):
new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation
if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'):
new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation
if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
return new_kwargs
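
The helper above fills in `prepare_inputs_fn` and `update_model_kwargs_fn` only when the caller did not pass them and the unwrapped HuggingFace model actually exposes the corresponding method. The same "default only if absent and available" pattern in isolation; `with_defaults` and `FakeHFModel` are illustrative stand-ins:

def with_defaults(kwargs, obj, mapping):
    # mapping: kwarg name -> attribute to borrow from `obj` when the kwarg is missing.
    new_kwargs = {**kwargs}
    for key, attr in mapping.items():
        if key not in new_kwargs and hasattr(obj, attr):
            new_kwargs[key] = getattr(obj, attr)
    return new_kwargs


class FakeHFModel:
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}


filled = with_defaults({}, FakeHFModel(), {"prepare_inputs_fn": "prepare_inputs_for_generation"})
print(sorted(filled))  # ['prepare_inputs_fn']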
@ -60,38 +60,34 @@ class PPOTrainer(OnPolicyTrainer):
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs
) -> None:
def __init__(
self,
strategy: Strategy,
actor: Actor,
critic: Critic,
reward_model: nn.Module,
initial_model: Actor,
actor_optim: Optimizer,
critic_optim: Optimizer,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(strategy, GeminiStrategy):
assert not offload_inference_models, \
"GeminiPlugin is not compatible with manual model.to('cpu')"
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(
strategy, data_buffer,
sample_buffer, dataloader_pin_memory,
callbacks
)
super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
@ -130,18 +126,16 @@ class PPOTrainer(OnPolicyTrainer):
num_actions = experience.action_mask.size(1)
actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask)
action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions)
actor_loss = self.actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
actor_loss = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
)
# ptx loss
if self.ptx_coef != 0:
batch = self.pretrain_dataloader.next()
batch = to_device(batch, self.device)
ptx_log_probs = self.actor(batch['input_ids'],
attention_mask=batch['attention_mask'])['logits']
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
ptx_log_probs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"])["logits"]
ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch["labels"])
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
@ -149,24 +143,23 @@ class PPOTrainer(OnPolicyTrainer):
self.actor_optim.zero_grad()
# value loss
values = self.critic(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self.critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
values = self.critic(
experience.sequences, action_mask=experience.action_mask, attention_mask=experience.attention_mask
)
critic_loss = self.critic_loss_fn(
values, experience.values, experience.reward, action_mask=experience.action_mask
)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad()
return {'reward': experience.reward.mean().item()}
return {"reward": experience.reward.mean().item()}
def _learn(self, update_step: int):
if self.offload_inference_models:
self.experience_maker.initial_model.to('cpu')
self.experience_maker.reward_model.to('cpu')
self.experience_maker.initial_model.to("cpu")
self.experience_maker.reward_model.to("cpu")
# buffer may be empty at first, we should rebuild at each training
if self.sample_buffer:
@ -178,11 +171,7 @@ class PPOTrainer(OnPolicyTrainer):
else:
if isinstance(self.dataloader.sampler, DistributedSampler):
self.dataloader.sampler.set_epoch(update_step)
pbar = tqdm(
self.dataloader,
desc=f'Train epoch [{update_step + 1}]',
disable=not is_rank_0()
)
pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
for experience in pbar:
self._on_learn_batch_start()
experience.to_device(self.device)

View File

@ -62,18 +62,15 @@ class RewardModelTrainer(SLTrainer):
if is_rank_0():
log = pd.DataFrame(
[[(epoch + 1) * len(self.train_dataloader),
self.loss.item(), self.dist, self.acc]],
columns=['step', 'loss', 'dist', 'acc']
[[(epoch + 1) * len(self.train_dataloader), self.loss.item(), self.dist, self.acc]],
columns=["step", "loss", "dist", "acc"],
)
log.to_csv('log.csv', mode='a', header=False, index=False)
log.to_csv("log.csv", mode="a", header=False, index=False)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0()
len(self.train_dataloader), desc="Train step of epoch %d" % epoch, disable=not is_rank_0()
)
cnt = 0
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
@ -93,10 +90,7 @@ class RewardModelTrainer(SLTrainer):
step_bar.update()
step_bar.close()
def _before_fit(self,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
eval_dataloader: DataLoader):
def _before_fit(self, train_dataloader: DataLoader, valid_dataloader: DataLoader, eval_dataloader: DataLoader):
"""
Args:
train_dataloader (DataLoader): the dataloader to use for training
@ -104,7 +98,7 @@ class RewardModelTrainer(SLTrainer):
eval_dataloader (DataLoader): the dataloader to use for evaluation
"""
super()._before_fit()
self.datetime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
self.datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
self.train_dataloader = train_dataloader
self.valid_dataloader = valid_dataloader

View File

@ -39,8 +39,9 @@ class SFTTrainer(SLTrainer):
accumulation_steps: int = 8,
) -> None:
if accumulation_steps > 1:
assert not isinstance(strategy, GeminiStrategy), \
"Accumulation steps are not supported in stage 3 of ColossalAI"
assert not isinstance(
strategy, GeminiStrategy
), "Accumulation steps are not supported in stage 3 of ColossalAI"
super().__init__(strategy, max_epochs, model, optim)
@ -50,15 +51,11 @@ class SFTTrainer(SLTrainer):
def _train(self, epoch: int):
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
if "attention_mask" in batch:
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
else:
outputs = self.model(batch["input_ids"],
labels=batch["labels"])
outputs = self.model(batch["input_ids"], labels=batch["labels"])
loss = outputs.loss
loss = loss / self.accumulation_steps
@ -73,12 +70,14 @@ class SFTTrainer(SLTrainer):
self.optimizer.zero_grad()
self.scheduler.step()
if is_rank_0() and self.use_wandb:
wandb.log({
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id
})
wandb.log(
{
"loss": self.total_loss / self.accumulation_steps,
"lr": self.scheduler.get_last_lr()[0],
"epoch": epoch,
"batch_id": batch_id,
}
)
self.total_loss = 0
self.step_bar.update()
@ -89,9 +88,9 @@ class SFTTrainer(SLTrainer):
loss_sum, num_seen = 0, 0
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
outputs = self.model(
batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
)
loss = outputs.loss
loss_sum += loss.item()
@ -99,13 +98,15 @@ class SFTTrainer(SLTrainer):
loss_mean = loss_sum / num_seen
if dist.get_rank() == 0:
self.logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
def _before_fit(self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
use_wandb: bool = False):
def _before_fit(
self,
train_dataloader: DataLoader,
eval_dataloader: Optional[DataLoader] = None,
logger: Optional[DistributedLogger] = None,
use_wandb: bool = False,
):
"""
Args:
train_dataloader: the dataloader to use for training
@ -124,6 +125,6 @@ class SFTTrainer(SLTrainer):
self.no_epoch_bar = True
self.step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps * self.max_epochs,
desc=f'steps',
disable=not is_rank_0()
desc=f"steps",
disable=not is_rank_0(),
)
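
The SFT loop above divides each batch loss by `accumulation_steps` and only steps the optimizer and scheduler once per window of `accumulation_steps` batches, so the effective batch size grows without extra memory. A self-contained sketch of that gating logic; the step count and loss values are made up:

accumulation_steps = 4
losses = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2]

total_loss, optimizer_steps = 0.0, 0
for batch_id, loss in enumerate(losses):
    total_loss += loss / accumulation_steps          # scale so gradients average over the window
    if (batch_id + 1) % accumulation_steps == 0:     # step only at the end of each window
        optimizer_steps += 1
        print(f"step {optimizer_steps}: mean loss {total_loss:.3f}")
        total_loss = 0.0
# two optimizer steps for the eight batches above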

View File

@ -2,7 +2,4 @@ from .base import Strategy
from .colossalai import GeminiStrategy, LowLevelZeroStrategy
from .ddp import DDPStrategy
__all__ = [
'Strategy', 'DDPStrategy',
'LowLevelZeroStrategy', 'GeminiStrategy'
]
__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]

View File

@ -19,7 +19,7 @@ _BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
class Strategy(ABC):
"""
Base class for training strategies.
Base class for training strategies.
"""
def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
@ -83,16 +83,18 @@ class Strategy(ABC):
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
boost_result = dict(model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
boost_result = dict(
model=model,
optimizer=optimizer,
criterion=criterion,
dataloader=dataloader,
lr_scheduler=lr_scheduler,
)
# remove None values
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f'Type {type(arg)} is not supported')
raise RuntimeError(f"Type {type(arg)} is not supported")
return rets[0] if len(rets) == 1 else rets
@ -125,11 +127,9 @@ class Strategy(ABC):
return DistributedSampler(dataset, 1, 0)
@abstractmethod
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
pass
@abstractmethod

View File

@ -42,27 +42,27 @@ class LowLevelZeroStrategy(DDPStrategy):
"""
def __init__(self,
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
def __init__(
self,
stage: int = 2,
precision: str = "fp16",
seed: int = 42,
placement_policy: str = "cuda",
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'
plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
@ -71,7 +71,7 @@ class LowLevelZeroStrategy(DDPStrategy):
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
cpu_offload=(placement_policy == "cpu"),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
@ -81,14 +81,15 @@ class LowLevelZeroStrategy(DDPStrategy):
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, LowLevelZeroPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
assert isinstance(
self.plugin, LowLevelZeroPlugin
), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
@ -131,45 +132,45 @@ class GeminiStrategy(DDPStrategy):
"""
def __init__(self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
def __init__(
self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "cuda",
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
) -> None:
assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
f"Shard init is not supported model.from_pretrained() yet. "
"Please load weights after strategy.prepare()"
)
self.shard_init = shard_init
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")
# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
@ -187,14 +188,13 @@ class GeminiStrategy(DDPStrategy):
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
norm_type=norm_type,
)
super().__init__(seed, plugin_initializer)
def _post_init(self) -> None:
assert isinstance(self.plugin, GeminiPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)
@ -203,10 +203,9 @@ class GeminiStrategy(DDPStrategy):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
return ColoInitContext(
device=get_current_device(), dtype=torch.half, default_pg=shard_pg, default_dist_spec=default_dist_spec
)
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel)

View File

@ -31,24 +31,21 @@ def get_grad_required_state_dict(model: nn.Module):
class DDPStrategy(Strategy):
"""
Strategy for distributed training using torch.distributed.
Strategy for distributed training using torch.distributed.
"""
def __init__(self,
seed: int = 42,
plugin_initializer: Callable = TorchDDPPlugin
) -> None:
def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
self.seed = seed
super().__init__(plugin_initializer)
def _try_init_dist(self, force: bool = False) -> None:
try:
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
host = os.environ['MASTER_ADDR']
port = int(os.environ['MASTER_PORT'])
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
torch.cuda.set_device(local_rank)
except KeyError as e:
if force:
@ -60,8 +57,7 @@ class DDPStrategy(Strategy):
raise e
def _post_init(self) -> None:
assert isinstance(self.plugin, TorchDDPPlugin), \
f'{type(self).__name__}\'s plugin is not initialized properly.'
assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."
def setup_distributed(self) -> None:
self._try_init_dist(force=True)
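
`_try_init_dist` pulls its rendezvous information from the standard torchrun environment variables, the same ones `set_dist_env` writes, and initializes the process group from them. A stripped-down sketch of that pattern; it uses the gloo backend so it can run without GPUs and a plain tcp:// init method, both simplifications of the NCCL setup above:

import os

import torch.distributed as dist


def init_dist_from_env(backend="gloo"):
    # RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are the variables torchrun exports.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])
    dist.init_process_group(backend, init_method=f"tcp://{host}:{port}", world_size=world_size, rank=rank)


# Single-process smoke test:
# RANK=0 WORLD_SIZE=1 MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 python this_file.py
if __name__ == "__main__":
    init_dist_from_env()
    print(f"initialized rank {dist.get_rank()} of {dist.get_world_size()}")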
@ -73,12 +69,14 @@ class DDPStrategy(Strategy):
torch.manual_seed(seed)
def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
return self.plugin.prepare_dataloader(data_buffer,
batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=data_buffer.collate_fn)
return self.plugin.prepare_dataloader(
data_buffer,
batch_size=data_buffer.sample_batch_size,
shuffle=True,
drop_last=True,
pin_memory=pin_memory,
collate_fn=data_buffer.collate_fn,
)
def setup_sampler(self, dataset) -> DistributedSampler:
# FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
@ -88,11 +86,9 @@ class DDPStrategy(Strategy):
assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
return model.unwrap()
def save_pretrained(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
if not only_rank0 or dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
@ -103,17 +99,11 @@ class DDPStrategy(Strategy):
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model,
model_path,
only_rank0=only_rank0)
self.save_model(model, model_path, only_rank0=only_rank0)
def _replace_keys(model_path: str,
replace_fn: Callable):
def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {
replace_fn(k): v
for k, v in state_dict.items()
}
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
@ -124,13 +114,13 @@ class DDPStrategy(Strategy):
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)
if 'requires_grad_only' in config and config['requires_grad_only'] == True:
if "requires_grad_only" in config and config["requires_grad_only"] == True:
state_dict = get_grad_required_state_dict(model)
else:
state_dict = model.state_dict()
if 'shard_size' in config:
shard_size = config['shard_size']
if "shard_size" in config:
shard_size = config["shard_size"]
accumulate_size = 0
state_dict_shard = OrderedDict()
for name, param in state_dict.items():

View File

@ -4,7 +4,6 @@ import numpy as np
class DistributedSampler:
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
self.dataset = dataset
self.num_replicas = num_replicas
@ -12,7 +11,7 @@ class DistributedSampler:
if len(self.dataset) % self.num_replicas != 0:
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
@ -20,10 +19,10 @@ class DistributedSampler:
self.total_size = self.num_samples * self.num_replicas
indices = list(range(len(self.dataset)))
indices = indices[:self.total_size]
indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
self.indices = indices
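
The sampler above truncates to `total_size = num_samples * num_replicas` and then takes every `num_replicas`-th index starting at `rank`, so each replica sees a disjoint strided slice. Worked through with ten items and four replicas (made-up sizes):

import math

dataset_len, num_replicas = 10, 4
# 10 is not divisible by 4, so the trailing incomplete group of indices is dropped.
num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)  # 2
total_size = num_samples * num_replicas                               # 8
indices = list(range(dataset_len))[:total_size]

for rank in range(num_replicas):
    print(rank, indices[rank:total_size:num_replicas])
# 0 [0, 4]   1 [1, 5]   2 [2, 6]   3 [3, 7]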

View File

@ -42,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:
def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)

View File

@ -70,7 +70,7 @@
"BLEU",
"ROUGE",
"BERTScore"
]
]
},
"logical_reasoning": {
"GPT": [
@ -83,7 +83,7 @@
"ROUGE",
"BERTScore",
"CHRF"
]
]
},
"open_qa": {
"GPT": [
@ -126,7 +126,7 @@
"conciseness"
],
"Metrics": [
]
]
},
"Finance": {
"GPT": [
@ -134,7 +134,7 @@
"correctness"
],
"Metrics": [
]
]
},
"Law": {
"GPT": [
@ -142,7 +142,7 @@
"correctness"
],
"Metrics": [
]
]
},
"Education": {
"GPT": [
@ -150,7 +150,7 @@
"correctness"
],
"Metrics": [
]
]
},
"Medical": {
"GPT": [
@ -158,7 +158,7 @@
"correctness"
],
"Metrics": [
]
]
},
"STEM": {
"GPT": [
@ -166,7 +166,7 @@
"correctness"
],
"Metrics": [
]
]
},
"SocialScience": {
"GPT": [
@ -174,7 +174,7 @@
"correctness"
],
"Metrics": [
]
]
},
"Humanity": {
"GPT": [
@ -182,7 +182,7 @@
"correctness"
],
"Metrics": [
]
]
},
"Other": {
"GPT": [
@ -190,7 +190,7 @@
"correctness"
],
"Metrics": [
]
]
},
"ethics": {
"GPT": [
@ -198,7 +198,7 @@
"correctness"
],
"Metrics": [
]
]
}
}
}

View File

@ -1,5 +1,4 @@
import argparse
import json
import os
import openai
@ -9,7 +8,8 @@ from utils import jload
def main(args):
assert len(args.answer_file_list) == len(
args.model_name_list), "The number of answer files and model names should be equal!"
args.model_name_list
), "The number of answer files and model names should be equal!"
# load config
config = jload(args.config_file)
@ -36,7 +36,8 @@ def main(args):
if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
raise Exception(
"No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!")
"No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!"
)
if args.gpt_model == "text-davinci-003" and args.gpt_with_reference:
raise Exception(
@ -44,8 +45,15 @@ def main(args):
)
# initialize evaluator
evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
config["language"], config.get("path_for_UniEval", None), args.gpt_with_reference)
evaluator = Evaluator(
metrics_per_category,
battle_prompt,
gpt_evaluation_prompt,
args.gpt_model,
config["language"],
config.get("path_for_UniEval", None),
args.gpt_with_reference,
)
if len(args.model_name_list) == 2:
answers1 = jload(args.answer_file_list[0])
answers2 = jload(args.answer_file_list[1])
@ -68,41 +76,41 @@ def main(args):
raise ValueError(f'Unsupported language {config["language"]}!')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.')
parser.add_argument('--config_file',
type=str,
default=None,
required=True,
help='path to the file of target results')
parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle')
parser.add_argument('--gpt_evaluation_prompt_file',
type=str,
default=None,
help='path to the prompt file for gpt evaluation')
parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file')
parser.add_argument('--answer_file_list',
type=str,
nargs='+',
default=[],
required=True,
help='path to the answer files of at most 2 models')
parser.add_argument('--model_name_list',
type=str,
nargs='+',
default=[],
required=True,
help='the names of at most 2 models')
parser.add_argument('--gpt_model',
default="gpt-3.5-turbo",
choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
help='which GPT model to use for evaluation')
parser.add_argument('--gpt_with_reference',
default=False,
action="store_true",
help='whether to include reference answer in gpt evaluation')
parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results')
parser.add_argument('--openai_key', type=str, default=None, required=True, help='Your openai key')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ColossalAI LLM evaluation pipeline.")
parser.add_argument(
"--config_file", type=str, default=None, required=True, help="path to the file of target results"
)
parser.add_argument("--battle_prompt_file", type=str, default=None, help="path to the prompt file for battle")
parser.add_argument(
"--gpt_evaluation_prompt_file", type=str, default=None, help="path to the prompt file for gpt evaluation"
)
parser.add_argument("--target_file", type=str, default=None, help="path to the target answer (ground truth) file")
parser.add_argument(
"--answer_file_list",
type=str,
nargs="+",
default=[],
required=True,
help="path to the answer files of at most 2 models",
)
parser.add_argument(
"--model_name_list", type=str, nargs="+", default=[], required=True, help="the names of at most 2 models"
)
parser.add_argument(
"--gpt_model",
default="gpt-3.5-turbo",
choices=["text-davinci-003", "gpt-3.5-turbo", "gpt-4"],
help="which GPT model to use for evaluation",
)
parser.add_argument(
"--gpt_with_reference",
default=False,
action="store_true",
help="whether to include reference answer in gpt evaluation",
)
parser.add_argument("--save_path", type=str, default="results", help="path to save evaluation results")
parser.add_argument("--openai_key", type=str, default=None, required=True, help="Your openai key")
args = parser.parse_args()
if args.openai_key is not None:

View File

@ -3,20 +3,27 @@ from typing import Any, Dict, List
import gpt_evaluate
import metrics
import pandas as pd
import unieval
from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
class Evaluator(object):
"""
A class named Evaluator includes GPT-3.5/GPT-4 evaluation
and automatic evaluation
A class named Evaluator includes GPT-3.5/GPT-4 evaluation
and automatic evaluation
"""
def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
gpt_model: str, language: str, path_for_UniEval: Dict[str, str], gpt_with_reference: bool) -> None:
def __init__(
self,
params: Dict[str, Any],
battle_prompt: Dict[str, Any],
gpt_evaluation_prompt: Dict[str, Any],
gpt_model: str,
language: str,
path_for_UniEval: Dict[str, str],
gpt_with_reference: bool,
) -> None:
self.params = params
self.battle_prompt = battle_prompt
self.gpt_evaluation_prompt = gpt_evaluation_prompt
@ -103,7 +110,8 @@ class Evaluator(object):
if self.params[category]["UniEval"] and self.language == "cn":
raise Exception(
"UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.")
"UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file."
)
category_metrics = self.params[category]["UniEval"]
@ -134,10 +142,9 @@ class Evaluator(object):
sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
scores = uni_evaluator.evaluate(data,
category,
dims=list(self.unieval_metric_stats[task][category].keys()),
overall=False)
scores = uni_evaluator.evaluate(
data, category, dims=list(self.unieval_metric_stats[task][category].keys()), overall=False
)
avg_scores = unieval.calculate_average_score(scores)
self.unieval_metric_stats[task][category].update(avg_scores)
@ -165,7 +172,8 @@ class Evaluator(object):
category,
self.gpt_model,
self.language,
references=targets_per_category[category] if self.gpt_with_reference else None)
references=targets_per_category[category] if self.gpt_with_reference else None,
)
def save(self, path: str, model_name_list: List[str]) -> None:
"""
@ -204,16 +212,18 @@ class Evaluator(object):
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
self.gpt_evaluation_results,
gpt_evaluation_results_save_path)
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(
model_name_list[0], self.gpt_evaluation_results, gpt_evaluation_results_save_path
)
# Start to calculate scores and save statistics.
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
gpt_evaluation_statistics_save_path)
gpt_evaluate.save_gpt_evaluation_statistics(
model_name_list[0], all_evaluations, gpt_evaluation_statistics_save_path
)
# Save charts and csv.
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
gpt_evaluation_analyses_save_path)
gpt_evaluate.analyze_gpt_evaluation_statistics(
gpt_evaluation_statistics_save_path, gpt_evaluation_analyses_save_path
)

View File

@ -14,20 +14,18 @@ import tqdm
from utils import jdump, jload
ref_step_template = {
"en":
"Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
"cn":
"请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}\n\n"
"en": "Now please compare the answer with the {adjective} answer, determine whether the answer is able to achieve the same level of {metric}.\n\n",
"cn": "请比较答案与上面的{adjective}答案,确定答案是否可以达到与该{adjective}答案同样水平的{metric}\n\n",
}
ref_answer_template_general = {
"en": "\nAn example answer with good quality is as follows:\n\n{answer}\n\n",
"cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n"
"cn": "\n一个优质的示例答案如下:\n\n{answer}\n\n",
}
ref_answer_template_correctness = {
"en": "\nA correct answer is as follows:\n\n{answer}\n\n",
"cn": "\n标准答案如下:\n\n{answer}\n\n"
"cn": "\n标准答案如下:\n\n{answer}\n\n",
}
@ -51,10 +49,7 @@ def get_battle_result(sys_prompt: str, user_prompt: str, id: int, max_tokens: in
response = openai.ChatCompletion.create(
model="gpt-4",
messages=[
{
"role": "system",
"content": sys_prompt
},
{"role": "system", "content": sys_prompt},
{
"role": "user",
"content": user_prompt,
@ -106,7 +101,7 @@ def parse_battle_score(evaluation: str) -> List[float]:
return [float(sp[0]), float(sp[1])]
else:
raise Exception(f"Invalid score pair. Got {evaluation}.")
except Exception as e:
except Exception:
return [-1, -1]
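
The battle parser falls back to `[-1, -1]` whenever it cannot extract a valid score pair from the judge's reply. A standalone sketch of one way such a pair could be parsed, under the assumption that the two scores sit on the first line of the reply; `parse_score_pair` is a hypothetical helper, not the repository's function:

from typing import List


def parse_score_pair(evaluation: str) -> List[float]:
    try:
        first_line = evaluation.strip().splitlines()[0]
        a, b = first_line.split()[:2]
        return [float(a), float(b)]
    except Exception:
        return [-1, -1]  # mark the pair as invalid, mirroring the fallback above


print(parse_score_pair("8 6\nAssistant 1 gave a more complete answer."))  # [8.0, 6.0]
print(parse_score_pair("no scores here"))                                 # [-1, -1]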
@ -125,9 +120,6 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert len(answer1) == len(answer2)
handles = []
evaluation_file = []
total_len = len(answer1)
question_idx_list = list(range(total_len))
@ -140,9 +132,12 @@ def battle(answer1: List[Dict], answer2: List[Dict], prompt_dict: Dict[str, Any]
assert answer1[i]["id"] == answer2[i]["id"]
answer_id = answer1[i]["id"]
ques = answer1[i]["instruction"] if answer1[i][
"input"] == "" else answer1[i]["instruction"] + " " + answer1[i]["input"]
cat = answer1[i]["category"]
ques = (
answer1[i]["instruction"]
if answer1[i]["input"] == ""
else answer1[i]["instruction"] + " " + answer1[i]["input"]
)
answer1[i]["category"]
ans1 = answer1[i]["output"]
ans2 = answer2[i]["output"]
@ -267,7 +262,11 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
step_to_add = ref_step_template[language]
for_the_given_answer = "{metric} (1-5) (directly give the score for the given answer):" if language == "en" else "{metric} (1-5) (直接对给定答案打分)"
for_the_given_answer = (
"{metric} (1-5) (directly give the score for the given answer):"
if language == "en"
else "{metric} (1-5) (直接对给定答案打分)"
)
# adjective is used to describe the word "answer" in the prompt.
adjective = "example" if language == "en" else "示例"
@ -280,8 +279,9 @@ def reference_template(metric: str, language: str, reference: Dict[str, Any]) ->
answer_to_add = ref_answer_template_correctness[language]
answer_to_add = answer_to_add.format(answer=reference["target"] if reference["target"] else reference["output"])
step_to_add = step_to_add.format(metric=metric.lower(),
adjective=adjective) + for_the_given_answer.format(metric=metric)
step_to_add = step_to_add.format(metric=metric.lower(), adjective=adjective) + for_the_given_answer.format(
metric=metric
)
return answer_to_add + step_to_add
@ -329,7 +329,8 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
for j in range(i):
messages_to_send.append(fill_in_message("user", user_messages[j]))
messages_to_send.append(
fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"]))
fill_in_message("assistant", assistant_responses[j]["choices"][0]["message"]["content"])
)
# Length of user messages == Length of assistant messages + 1
# Because we always expect the api to respond
@ -351,13 +352,15 @@ def multiturn_chat_completion(user_messages: List[str], model: str, max_tokens:
return assistant_responses[-1]
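
`multiturn_chat_completion` replays the conversation so far before each new turn: for every earlier round it appends the user message and then the recorded assistant reply, and finally the current user message, which keeps the user messages one ahead of the assistant ones. A minimal sketch of just that message assembly, with plain strings in place of API response objects and no API calls:

def fill_in_message(role, content):
    return {"role": role, "content": content}


def build_messages(user_messages, assistant_replies, turn):
    # Replay rounds 0..turn-1, then add the current user message for round `turn`.
    messages = []
    for j in range(turn):
        messages.append(fill_in_message("user", user_messages[j]))
        messages.append(fill_in_message("assistant", assistant_replies[j]))
    messages.append(fill_in_message("user", user_messages[turn]))
    return messages


msgs = build_messages(["rate this answer", "compare with the reference"], ["Score: 4"], turn=1)
print([m["role"] for m in msgs])  # ['user', 'assistant', 'user']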
def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
inst: Dict[str, Any],
metrics: List[str],
language: str,
reference: Dict[str, Any] = None,
model: str = "gpt-3.5-turbo",
max_tokens: int = 2048) -> Dict[str, Any]:
def get_gpt_evaluation_without_logprobs(
prompt: Dict[str, Any],
inst: Dict[str, Any],
metrics: List[str],
language: str,
reference: Dict[str, Any] = None,
model: str = "gpt-3.5-turbo",
max_tokens: int = 2048,
) -> Dict[str, Any]:
"""
Use chat models(gpt-3.5-turbo or gpt-4) to evaluate one model answer.
@ -378,7 +381,7 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3
question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"]
inst["evaluation"] = {}
@ -400,10 +403,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
if prompt_reference:
# Do a 2-round conversation
response = multiturn_chat_completion([prompt_1st_round, prompt_reference],
model,
max_tokens=max_tokens,
turns=2)
response = multiturn_chat_completion(
[prompt_1st_round, prompt_reference], model, max_tokens=max_tokens, turns=2
)
else:
response = multiturn_chat_completion([prompt_1st_round], model, max_tokens=max_tokens, turns=1)
@ -427,10 +429,9 @@ def get_gpt_evaluation_without_logprobs(prompt: Dict[str, Any],
return inst
def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
inst: Dict[str, Any],
metrics: List[str],
max_tokens: int = 2048) -> Dict[str, Any]:
def get_gpt_evaluation_with_logprobs(
prompt: Dict[str, Any], inst: Dict[str, Any], metrics: List[str], max_tokens: int = 2048
) -> Dict[str, Any]:
"""
Use completion model(text-davinci-003) to evaluate one model answer.
Only completion models can return log probabilities.
@ -449,7 +450,7 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
MAX_API_RETRY = 3
question = (inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"])
question = inst["instruction"] if inst["input"] == "" else inst["instruction"] + "\n" + inst["input"]
answer = inst["output"]
inst["evaluation"] = {}
@ -492,13 +493,15 @@ def get_gpt_evaluation_with_logprobs(prompt: Dict[str, Any],
return inst
def evaluate(answers: List[Dict],
prompt: Dict[str, Any],
metrics: List[str],
category: str,
model: str,
language: str,
references: List[Dict] = None) -> List[Dict]:
def evaluate(
answers: List[Dict],
prompt: Dict[str, Any],
metrics: List[str],
category: str,
model: str,
language: str,
references: List[Dict] = None,
) -> List[Dict]:
"""
Use GPT models to evaluate model answers and save evaluation results.
@ -529,21 +532,23 @@ def evaluate(answers: List[Dict],
if model == "text-davinci-003":
future = executor.submit(get_gpt_evaluation_with_logprobs, prompt, inst, metrics, 1)
else:
future = executor.submit(get_gpt_evaluation_without_logprobs,
prompt,
inst,
metrics,
language,
reference=None if references is None else references[idx],
model=model,
max_tokens=1)
future = executor.submit(
get_gpt_evaluation_without_logprobs,
prompt,
inst,
metrics,
language,
reference=None if references is None else references[idx],
model=model,
max_tokens=1,
)
futures.append(future)
for future in tqdm.tqdm(
concurrent.futures.as_completed(futures),
desc=f"{category}: ",
total=len(futures),
concurrent.futures.as_completed(futures),
desc=f"{category}: ",
total=len(futures),
):
evaluations.append(future.result())
@ -610,12 +615,13 @@ def calculate_scores_form_response(response: str, evaluation: Dict[str, Any]) ->
return int(results[0])
else:
raise Exception(f"Invalid score pair. Got {evaluation}.")
except Exception as e:
except Exception:
return 0
def save_gpt_evaluation_results(model_name: str, gpt_evaluation_results: Dict[str, Any],
save_path: str) -> Dict[str, Any]:
def save_gpt_evaluation_results(
model_name: str, gpt_evaluation_results: Dict[str, Any], save_path: str
) -> Dict[str, Any]:
"""
Save evaluation results for different categories for one model.
@ -667,10 +673,12 @@ def save_gpt_evaluation_statistics(model_name: str, evaluations: List[Dict], sav
scores[metric].append(0)
elif evaluation["evaluation"][metric]["logprobs"] is not None:
scores[metric].append(
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0]))
calculate_scores_form_logprobs(evaluation["evaluation"][metric]["logprobs"][0])
)
else:
scores[metric].append(
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation))
calculate_scores_form_response(evaluation["evaluation"][metric]["response"], evaluation)
)
statistics = {}
for metric in metrics:
@ -751,9 +759,9 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
frame_all.to_csv(os.path.join(save_path, "gpt_evaluation_statistics.csv"))
for category in tqdm.tqdm(
frame_per_category.keys(),
desc=f"GPT evaluation: ",
total=len(frame_per_category.keys()),
frame_per_category.keys(),
desc=f"GPT evaluation: ",
total=len(frame_per_category.keys()),
):
data = pd.DataFrame(frame_per_category[category])

View File

@ -21,13 +21,17 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
"""
bleu_scores = {"bleu1": 0, "bleu2": 0, "bleu3": 0, "bleu4": 0}
cumulative_bleu = [0] * 4
weights = [(1. / 1., 0., 0., 0.), (1. / 2., 1. / 2., 0., 0.), (1. / 3., 1. / 3., 1. / 3., 0.),
(1. / 4., 1. / 4., 1. / 4., 1. / 4.)]
weights = [
(1.0 / 1.0, 0.0, 0.0, 0.0),
(1.0 / 2.0, 1.0 / 2.0, 0.0, 0.0),
(1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, 0.0),
(1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0),
]
for pred, target in zip(preds, targets):
if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
target_list = [(' '.join(jieba.cut(preprocessing_text(target)))).split()]
pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
target_list = [(" ".join(jieba.cut(preprocessing_text(target)))).split()]
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = [preprocessing_text(target).split()]
@ -42,15 +46,14 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
"""Calculate CHRF Score Metric in sentence level.
"""
"""Calculate CHRF Score Metric in sentence level."""
chrf_score = {"chrf": 0}
cumulative_chrf = []
for pred, target in zip(preds, targets):
if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
pred_list = " ".join(jieba.cut(preprocessing_text(pred))).split()
target_list = " ".join(jieba.cut(preprocessing_text(target))).split()
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = preprocessing_text(target).split()
@ -75,8 +78,8 @@ def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
all_targets = []
for pred, target in zip(preds, targets):
pred_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(pred))))
target_list = remove_redundant_space(' '.join(jieba.cut(preprocessing_text(target))))
pred_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(pred))))
target_list = remove_redundant_space(" ".join(jieba.cut(preprocessing_text(target))))
all_preds.append(pred_list)
all_targets.append(target_list)
@ -99,16 +102,14 @@ def rouge_en_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
longest common subsequence (LCS) between preds and targets.
"""
rouge_scores = {"rouge1": 0, "rouge2": 0, "rougeL": 0}
all_preds = []
all_targets = []
rouge_en = Rouge_en.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)
for pred, target in zip(preds, targets):
score = rouge_en.score(preprocessing_text(pred), preprocessing_text(target))
rouge_scores["rouge1"] += score['rouge1'].fmeasure
rouge_scores["rouge2"] += score['rouge2'].fmeasure
rouge_scores["rougeL"] += score['rougeL'].fmeasure
rouge_scores["rouge1"] += score["rouge1"].fmeasure
rouge_scores["rouge2"] += score["rouge2"].fmeasure
rouge_scores["rougeL"] += score["rougeL"].fmeasure
rouge_scores["rouge1"] = rouge_scores["rouge1"] / len(preds)
rouge_scores["rouge2"] = rouge_scores["rouge2"] / len(preds)
@ -137,7 +138,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
for pred in preds:
if language == "cn":
pred_seg_list = ' '.join(jieba.cut(pred)).split()
pred_seg_list = " ".join(jieba.cut(pred)).split()
count_segs = len(pred_seg_list)
unique_segs = set(pred_seg_list)
count_unique_chars = len(unique_segs)
@ -151,7 +152,7 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]:
split_pred = preprocessing_text(pred).split()
for n in range(0, 3):
for i in range(0, len(split_pred) - n):
ngram = ' '.join(split_pred[i:i + n + 1])
ngram = " ".join(split_pred[i : i + n + 1])
unique_ngram[n].add(ngram)
all_ngram_count[n] += 1
@ -203,8 +204,8 @@ def calculate_precision_recall_f1(preds: List[str], targets: List[str], language
for pred, target in zip(preds, targets):
if language == "cn":
pred_list = [char for char in ' '.join(jieba.cut(preprocessing_text(pred))).split()]
target_list = [char for char in ' '.join(jieba.cut(preprocessing_text(target))).split()]
pred_list = [char for char in " ".join(jieba.cut(preprocessing_text(pred))).split()]
target_list = [char for char in " ".join(jieba.cut(preprocessing_text(target))).split()]
elif language == "en":
pred_list = [char for char in preprocessing_text(pred).split()]
target_list = [char for char in preprocessing_text(target).split()]
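
Earlier in this file the distinct score collects unique n-grams with `" ".join(split_pred[i : i + n + 1])` for n-gram orders 1 to 3 and divides by the total n-gram count. A standalone sketch of that counting for a single English sentence (the sentence is made up):

split_pred = "the cat sat on the mat".split()

unique_ngram = [set() for _ in range(3)]
all_ngram_count = [0] * 3
for n in range(0, 3):
    for i in range(0, len(split_pred) - n):
        ngram = " ".join(split_pred[i : i + n + 1])
        unique_ngram[n].add(ngram)
        all_ngram_count[n] += 1

for n in range(3):
    print(f"distinct-{n + 1}: {len(unique_ngram[n])}/{all_ngram_count[n]}")
# distinct-1: 5/6, distinct-2: 5/5, distinct-3: 4/4 ("the" repeats, longer n-grams do not)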

View File

@ -7,6 +7,9 @@ from .utils import (
)
__all__ = [
'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
'analyze_unieval_results'
"get_evaluator",
"convert_data_to_unieval_format",
"calculate_average_score",
"save_unieval_results",
"analyze_unieval_results",
]

View File

@ -28,29 +28,29 @@ from .utils import add_question
class SumEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for text summarization """
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
"""Set up evaluator for text summarization"""
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'summarization'
self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
cache_dir=cache_dir,
)
self.task = "summarization"
self.dimensions = ["coherence", "consistency", "fluency", "relevance"]
def evaluate(self, data, category, dims=None, overall=True):
"""
Get the scores of all the given dimensions
Get the scores of all the given dimensions
category: The category to be evaluated.
category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
four dimensions: coherence, consistency, fluency, relevance.
dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
four dimensions: coherence, consistency, fluency, relevance.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
@ -63,12 +63,12 @@ class SumEvaluator:
for dim in eval_dims:
# Calculate average sentence-level scores for 'consistency' and 'fluency'
if dim == 'consistency' or dim == 'fluency':
if dim == "consistency" or dim == "fluency":
src_list, output_list = [], []
n_sents = [] # the number of sentences in each generated summary
n_sents = [] # the number of sentences in each generated summary
for i in range(n_data):
source = data[i]['source']
system_outputs = sent_tokenize(data[i]['system_output'])
source = data[i]["source"]
system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
@ -81,24 +81,26 @@ class SumEvaluator:
score = []
for cur_n_sent in n_sents:
# prevent denominator from being 0
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
start_idx += cur_n_sent
# Calculate summary-level score for 'coherence' and 'relevance'
elif dim == 'coherence' or dim == 'relevance':
elif dim == "coherence" or dim == "relevance":
src_list, output_list, ref_list = [], [], []
for i in range(n_data):
src_list.append(data[i]['source'])
output_list.append(data[i]['system_output'])
if dim == 'relevance':
ref_list.append(data[i]['reference'])
src_list.append(data[i]["source"])
output_list.append(data[i]["system_output"])
if dim == "relevance":
ref_list.append(data[i]["reference"])
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization
else:
raise NotImplementedError('The input format for this dimension is still undefined. \
Please customize it first.')
raise NotImplementedError(
"The input format for this dimension is still undefined. \
Please customize it first."
)
for i in range(n_data):
eval_scores[i][dim] = score[i]
@ -106,35 +108,35 @@ class SumEvaluator:
# Customize your overall score here.
if overall == True:
for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores
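
For the sentence-level dimensions the evaluator scores every sentence of every summary in one flat batch, then walks that flat list with `n_sents` and `start_idx` to recover a per-sample average, with the `1e-6` term guarding against empty outputs. A small sketch of that regrouping step with made-up scores:

# Flat per-sentence scores for three samples with 2, 1 and 3 sentences respectively.
sent_score = [0.8, 0.6, 0.9, 0.5, 0.7, 0.3]
n_sents = [2, 1, 3]

start_idx, score = 0, []
for cur_n_sent in n_sents:
    # Average this sample's sentence scores; 1e-6 prevents division by zero for empty outputs.
    score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / (cur_n_sent + 1e-6))
    start_idx += cur_n_sent

print([round(s, 2) for s in score])  # [0.7, 0.9, 0.5]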
class DialogEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for dialogues """
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
"""Set up evaluator for dialogues"""
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
model_name_or_path="MingZhong/unieval-dialog" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'dialogue'
self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
cache_dir=cache_dir,
)
self.task = "dialogue"
self.dimensions = ["naturalness", "coherence", "engagingness", "groundedness", "understandability"]
def evaluate(self, data, category, dims=None, overall=True):
"""
Get the scores of all the given dimensions
Get the scores of all the given dimensions
category: The category to be evaluated.
category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
@ -147,50 +149,48 @@ class DialogEvaluator:
for dim in eval_dims:
# Calculate summation score for 'engagingness'
if dim == 'engagingness':
if dim == "engagingness":
src_list, output_list, context_list = [], [], []
n_sents = [] # the number of sentences in each generated response
n_sents = [] # the number of sentences in each generated response
for i in range(n_data):
source = data[i]['source']
context = data[i]['context']
system_outputs = sent_tokenize(data[i]['system_output'])
source = data[i]["source"]
context = data[i]["context"]
system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
context_list.append(context)
output_list.append(system_outputs[j])
input_list = add_question(dimension=dim,
output=output_list,
src=src_list,
context=context_list,
task=self.task)
input_list = add_question(
dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
)
sent_score = self.scorer.score(input_list, self.task, category, dim)
# Get the summation score for each sample
start_idx = 0
score = []
for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]))
start_idx += cur_n_sent
# Calculate turn-level score for other dimensions
elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
elif dim in ["naturalness", "coherence", "groundedness", "understandability"]:
src_list, output_list, context_list = [], [], []
for i in range(n_data):
src_list.append(data[i]['source'])
output_list.append(data[i]['system_output'])
context_list.append(data[i]['context'])
input_list = add_question(dimension=dim,
output=output_list,
src=src_list,
context=context_list,
task=self.task)
src_list.append(data[i]["source"])
output_list.append(data[i]["system_output"])
context_list.append(data[i]["context"])
input_list = add_question(
dimension=dim, output=output_list, src=src_list, context=context_list, task=self.task
)
score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization
else:
raise NotImplementedError('The input format for this dimension is still undefined. \
Please customize it first.')
raise NotImplementedError(
"The input format for this dimension is still undefined. \
Please customize it first."
)
for i in range(n_data):
eval_scores[i][dim] = score[i]
@ -198,35 +198,35 @@ class DialogEvaluator:
# Customize your overall score here.
if overall == True:
for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores
class D2tEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for data-to-text """
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
"""Set up evaluator for data-to-text"""
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
model_name_or_path="MingZhong/unieval-sum" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'data2text'
self.dimensions = ['naturalness', 'informativeness']
cache_dir=cache_dir,
)
self.task = "data2text"
self.dimensions = ["naturalness", "informativeness"]
def evaluate(self, data, category, dims=None, overall=True):
"""
Get the scores of all the given dimensions
Get the scores of all the given dimensions
category: The category to be evaluated.
category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
two dimensions: naturalness and informativeness.
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
two dimensions: naturalness and informativeness.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
@ -240,8 +240,8 @@ class D2tEvaluator:
for dim in eval_dims:
output_list, ref_list = [], []
for i in range(n_data):
output_list.append(data[i]['system_output'])
ref_list.append(data[i]['reference'])
output_list.append(data[i]["system_output"])
ref_list.append(data[i]["reference"])
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
@ -252,38 +252,38 @@ class D2tEvaluator:
# Customize your overall score here.
if overall == True:
for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
eval_scores[i]["overall"] = np.mean(list(eval_scores[i].values()))
return eval_scores
class FactEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for factual consistency detection """
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
"""Set up evaluator for factual consistency detection"""
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
model_name_or_path="MingZhong/unieval-fact" if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'fact'
self.dim = 'consistency'
cache_dir=cache_dir,
)
self.task = "fact"
self.dim = "consistency"
def evaluate(self, data, category):
"""
Get the factual consistency score (only 1 dimension for this task)
Get the factual consistency score (only 1 dimension for this task)
category: The category to be evaluated.
category: The category to be evaluated.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
# Calculate average sentence-level scores for factual consistency
src_list, output_list = [], []
n_sents = [] # the number of sentences in the claim
n_sents = [] # the number of sentences in the claim
for i in range(n_data):
source = data[i]['source']
system_outputs = sent_tokenize(data[i]['system_output'])
source = data[i]["source"]
system_outputs = sent_tokenize(data[i]["system_output"])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
@ -295,7 +295,7 @@ class FactEvaluator:
start_idx = 0
score = []
for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
score.append(sum(sent_score[start_idx : start_idx + cur_n_sent]) / cur_n_sent)
start_idx += cur_n_sent
for i in range(n_data):
@ -304,28 +304,26 @@ class FactEvaluator:
return eval_scores
def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
assert task in ['summarization', 'dialogue', 'data2text', 'fact']
if task == 'summarization':
return SumEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
elif task == 'dialogue':
return DialogEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
elif task == 'data2text':
return D2tEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
elif task == 'fact':
return FactEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
def get_evaluator(task, model_name_or_path="", max_length=1024, device="cuda:0", cache_dir=None):
assert task in ["summarization", "dialogue", "data2text", "fact"]
if task == "summarization":
return SumEvaluator(
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
)
elif task == "dialogue":
return DialogEvaluator(
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
)
elif task == "data2text":
return D2tEvaluator(
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
)
elif task == "fact":
return FactEvaluator(
model_name_or_path=model_name_or_path, max_length=max_length, device=device, cache_dir=cache_dir
)
else:
raise NotImplementedError('Other tasks are not implemented, \
please customize specific tasks here.')
raise NotImplementedError(
"Other tasks are not implemented, \
please customize specific tasks here."
)
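As a usage sketch (not part of this diff), the evaluators above can be driven roughly as follows; the import path, sample texts and category label are illustrative assumptions:

from unieval.evaluator import get_evaluator  # hypothetical import path

# One sample in the expected format: "source", "system_output" and, for relevance, "reference".
data = [
    {
        "source": "A cat was sitting on a small mat in the kitchen.",
        "system_output": "The cat sat on the mat.",
        "reference": "A cat sat on a mat.",
    }
]
evaluator = get_evaluator("summarization", device="cuda:0")
scores = evaluator.evaluate(data, category="brainstorming", dims=["coherence", "relevance"])
# scores[0] is a dict with "coherence", "relevance" and the averaged "overall" score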

@ -27,9 +27,8 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
class UniEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up model """
def __init__(self, model_name_or_path, max_length=1024, device="cuda:0", cache_dir=None):
"""Set up model"""
self.device = device
self.max_length = max_length
@ -47,8 +46,8 @@ class UniEvaluator:
def score(self, inputs, task, category, dim, batch_size=8):
"""
Get scores for the given samples.
final_score = positive_score / (positive_score + negative_score)
Get scores for the given samples.
final_score = positive_score / (positive_score + negative_score)
"""
# The implementation of "forward" in T5 still requires decoder_input_ids.
@ -58,31 +57,27 @@ class UniEvaluator:
pos_score_list, neg_score_list = [], []
for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
src_list = inputs[i:i + batch_size]
tgt_list = tgts[i:i + batch_size]
src_list = inputs[i : i + batch_size]
tgt_list = tgts[i : i + batch_size]
try:
with torch.no_grad():
encoded_src = self.tokenizer(src_list,
max_length=self.max_length,
truncation=True,
padding=True,
return_tensors='pt')
encoded_tgt = self.tokenizer(tgt_list,
max_length=self.max_length,
truncation=True,
padding=True,
return_tensors='pt')
encoded_src = self.tokenizer(
src_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
)
encoded_tgt = self.tokenizer(
tgt_list, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
)
src_tokens = encoded_src['input_ids'].to(self.device)
src_mask = encoded_src['attention_mask'].to(self.device)
src_tokens = encoded_src["input_ids"].to(self.device)
src_mask = encoded_src["attention_mask"].to(self.device)
tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
tgt_tokens = encoded_tgt["input_ids"].to(self.device)[:, 0].unsqueeze(-1)
output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
logits = output.logits.view(-1, self.model.config.vocab_size)
pos_score = self.softmax(logits)[:, self.pos_id] # Yes
neg_score = self.softmax(logits)[:, self.neg_id] # No
pos_score = self.softmax(logits)[:, self.pos_id] # Yes
neg_score = self.softmax(logits)[:, self.neg_id] # No
cur_pos_score = [x.item() for x in pos_score]
cur_neg_score = [x.item() for x in neg_score]
@ -90,8 +85,8 @@ class UniEvaluator:
neg_score_list += cur_neg_score
except RuntimeError:
print(f'source: {src_list}')
print(f'target: {tgt_list}')
print(f"source: {src_list}")
print(f"target: {tgt_list}")
exit(0)
score_list = []
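The docstring above defines final_score = positive_score / (positive_score + negative_score); a minimal sketch of that last step, with illustrative probabilities standing in for the accumulated pos_score_list / neg_score_list:

pos_score_list = [0.92, 0.40]  # P("Yes") per sample, illustrative values
neg_score_list = [0.05, 0.55]  # P("No") per sample, illustrative values
score_list = [pos / (pos + neg) for pos, neg in zip(pos_score_list, neg_score_list)]
# first sample: 0.92 / (0.92 + 0.05) ≈ 0.948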

@ -31,105 +31,142 @@ import tqdm
def add_question(dimension, output, src=None, ref=None, context=None, task=None):
"""
Add questions to generate input in Bool-QA format for UniEval.
Add questions to generate input in Bool-QA format for UniEval.
dimension: specific dimension to be evaluated
src: source input for different NLG tasks. For example, source document for summarization
and dialogue history for dialogue response generation.
output: output text generated by the models
ref: human-annotated groundtruth
context: the context needed to evaluate several specific dimensions. For example,
additional factual information when evaluating engagingness and groundedness in dialogues.
dimension: specific dimension to be evaluated
src: source input for different NLG tasks. For example, source document for summarization
and dialogue history for dialogue response generation.
output: output text generated by the models
ref: human-annotated groundtruth
context: the context needed to evaluate several specific dimensions. For example,
additional factual information when evaluating engagingness and groundedness in dialogues.
"""
input_with_question = []
for i in range(len(output)):
# For summarization
if task == 'summarization':
if dimension == 'fluency':
cur_input = 'question: Is this a fluent paragraph? </s> paragraph: ' + output[i]
elif dimension == 'coherence':
cur_input = 'question: Is this a coherent summary to the document? </s> summary: ' + output[
i] + ' </s> document: ' + src[i]
elif dimension == 'consistency':
cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
i] + ' </s> document: ' + src[i]
elif dimension == 'relevance':
cur_input = 'question: Is this summary relevant to the reference? </s> summary: ' + output[
i] + ' </s> reference: ' + ref[i]
if task == "summarization":
if dimension == "fluency":
cur_input = "question: Is this a fluent paragraph? </s> paragraph: " + output[i]
elif dimension == "coherence":
cur_input = (
"question: Is this a coherent summary to the document? </s> summary: "
+ output[i]
+ " </s> document: "
+ src[i]
)
elif dimension == "consistency":
cur_input = (
"question: Is this claim consistent with the document? </s> claim: "
+ output[i]
+ " </s> document: "
+ src[i]
)
elif dimension == "relevance":
cur_input = (
"question: Is this summary relevant to the reference? </s> summary: "
+ output[i]
+ " </s> reference: "
+ ref[i]
)
else:
raise NotImplementedError(
'The input format for this dimension is still undefined. Please customize it first.')
"The input format for this dimension is still undefined. Please customize it first."
)
# For dialogues
elif task == 'dialogue':
if dimension == 'naturalness':
cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + output[i]
elif dimension == 'coherence':
cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: '\
+ output[i] + ' </s> dialogue history: ' + src[i]
elif dimension == 'engagingness':
cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: '\
+ output[i] + ' </s> dialogue history: ' + src[i] + ' </s> fact: ' + context[i]
elif dimension == 'groundedness':
cur_input = 'question: Is this response consistent with knowledge in the fact? </s> response: '\
+ output[i] + ' </s> fact: ' + context[i]
elif dimension == 'understandability':
cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + output[i]
elif task == "dialogue":
if dimension == "naturalness":
cur_input = "question: Is this a natural response in the dialogue? </s> response: " + output[i]
elif dimension == "coherence":
cur_input = (
"question: Is this a coherent response given the dialogue history? </s> response: "
+ output[i]
+ " </s> dialogue history: "
+ src[i]
)
elif dimension == "engagingness":
cur_input = (
"question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: "
+ output[i]
+ " </s> dialogue history: "
+ src[i]
+ " </s> fact: "
+ context[i]
)
elif dimension == "groundedness":
cur_input = (
"question: Is this response consistent with knowledge in the fact? </s> response: "
+ output[i]
+ " </s> fact: "
+ context[i]
)
elif dimension == "understandability":
cur_input = "question: Is this an understandable response in the dialogue? </s> response: " + output[i]
else:
raise NotImplementedError(
'The input format for this dimension is still undefined. Please customize it first.')
"The input format for this dimension is still undefined. Please customize it first."
)
# For data-to-text
elif task == 'data2text':
if dimension == 'naturalness':
cur_input = 'question: Is this a fluent utterance? </s> utterance: ' + output[i]
elif dimension == 'informativeness':
cur_input = 'question: Is this sentence informative according to the reference? </s> sentence: '\
+ output[i] + ' </s> reference: ' + ref[i]
elif task == "data2text":
if dimension == "naturalness":
cur_input = "question: Is this a fluent utterance? </s> utterance: " + output[i]
elif dimension == "informativeness":
cur_input = (
"question: Is this sentence informative according to the reference? </s> sentence: "
+ output[i]
+ " </s> reference: "
+ ref[i]
)
else:
raise NotImplementedError(
'The input format for this dimension is still undefined. Please customize it first.')
"The input format for this dimension is still undefined. Please customize it first."
)
# For factual consistency detection
elif task == 'fact':
if dimension == 'consistency':
cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
i] + ' </s> document: ' + src[i]
elif task == "fact":
if dimension == "consistency":
cur_input = (
"question: Is this claim consistent with the document? </s> claim: "
+ output[i]
+ " </s> document: "
+ src[i]
)
else:
raise NotImplementedError('No other dimensions for the factual consistency detection task.')
raise NotImplementedError("No other dimensions for the factual consistency detection task.")
# For new customized tasks
else:
raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
raise NotImplementedError("Other tasks are not implemented, please customize specific tasks here.")
input_with_question.append(cur_input)
return input_with_question
def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
"""
Convert the data into the unieval's format.
Convert the data into the unieval's format.
output_list: a list of model output
output_list: a list of model output
src_list: source input for different NLG tasks. For example, source document for summarization
and dialogue history for dialogue response generation
ref_list: human-annotated groundtruth
src_list: source input for different NLG tasks. For example, source document for summarization
and dialogue history for dialogue response generation
ref_list: human-annotated groundtruth
"""
json_data = []
for i in range(len(output_list)):
cur = {}
cur['system_output'] = output_list[i]
cur["system_output"] = output_list[i]
if src_list is not None:
cur['source'] = src_list[i]
cur["source"] = src_list[i]
if ref_list is not None:
cur['reference'] = ref_list[i]
cur['context'] = ""
cur["reference"] = ref_list[i]
cur["context"] = ""
json_data.append(cur)
return json_data
def calculate_average_score(scores):
"""
Calculate average scores for different metrics
Calculate average scores for different metrics
scores: a list of scores for different metrics for each answer
scores: a list of scores for different metrics for each answer
"""
metrics = {metric: 0 for metric in scores[0]}
@ -226,9 +263,9 @@ def analyze_unieval_results(results_path: str, save_path: str) -> None:
frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
for metric in tqdm.tqdm(
frame_per_metric.keys(),
desc=f"UniEval metrics: ",
total=len(frame_per_metric.keys()),
frame_per_metric.keys(),
desc=f"UniEval metrics: ",
total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])
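To make the Bool-QA input format concrete, here is a small sketch using the two helpers defined above; the sample texts are made up:

output_list = ["A cat sat on a mat."]
src_list = ["A cat was sitting on a small mat in the kitchen."]

data = convert_data_to_unieval_format(output_list, src_list=src_list)
# data[0] == {"system_output": "A cat sat on a mat.",
#             "source": "A cat was sitting on a small mat in the kitchen.", "context": ""}

inputs = add_question(dimension="coherence", output=output_list, src=src_list, task="summarization")
# inputs[0] == "question: Is this a coherent summary to the document? </s> summary: A cat sat on a mat."
#              " </s> document: A cat was sitting on a small mat in the kitchen."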

@ -1,7 +1,6 @@
import io
import json
import os
import re
import string
from typing import Dict
@ -55,7 +54,7 @@ def jload(f, mode="r"):
def get_json_list(file_path):
with open(file_path, 'r') as f:
with open(file_path, "r") as f:
json_list = []
for line in f:
json_list.append(json.loads(line))
@ -187,9 +186,9 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
frame_all.to_csv(os.path.join(save_path, "automatic_evaluation_statistics.csv"))
for metric in tqdm.tqdm(
frame_per_metric.keys(),
desc=f"automatic metrics: ",
total=len(frame_per_metric.keys()),
frame_per_metric.keys(),
desc=f"automatic metrics: ",
total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])

@ -3,7 +3,6 @@ import json
from typing import Dict, Sequence
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@ -20,7 +19,8 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: i
padding="longest",
max_length=max_length,
truncation=True,
) for text in strings
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
@ -48,18 +48,17 @@ def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTo
class EasySupervisedDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None:
super(EasySupervisedDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
#split into source and target: source is the text up to and including "回答:", target is the text after "回答:"
# split into source and target: source is the text up to and including "回答:", target is the text after "回答:"
sources, targets = [], []
for line in all_lines:
if "回答:" in line:
sep_index = line.index("回答:")
sources.append(line[:sep_index + 3])
targets.append(line[sep_index + 3:] + tokenizer.eos_token)
sources.append(line[: sep_index + 3])
targets.append(line[sep_index + 3 :] + tokenizer.eos_token)
else:
sources.append(line)
targets.append("" + tokenizer.eos_token)
@ -83,15 +82,17 @@ class EasySupervisedDataset(Dataset):
class EasyPromptsDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None:
super(EasyPromptsDataset, self).__init__()
with open(data_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines]
all_lines = [line if "回答:" not in line else line[: line.index("回答:") + 3] for line in all_lines]
self.prompts = [
tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length',
truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0)
tokenizer(line, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)[
"input_ids"
]
.to(torch.cuda.current_device())
.squeeze(0)
for line in tqdm(all_lines)
]
self.data_file = data_file
@ -110,7 +111,6 @@ class EasyPromptsDataset(Dataset):
class EasyRewardDataset(Dataset):
def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None:
super(EasyRewardDataset, self).__init__()
self.chosen = []
@ -120,44 +120,42 @@ class EasyRewardDataset(Dataset):
else:
self.end_token = special_token
print(self.end_token)
#read all lines in the train_file to a list
# read all lines in the train_file to a list
with open(train_file, "r", encoding="UTF-8") as f:
all_lines = f.readlines()
for line in tqdm(all_lines):
data = json.loads(line)
prompt = "提问:" + data['prompt'] + " 回答:"
prompt = "提问:" + data["prompt"] + " 回答:"
chosen = prompt + data['chosen'] + self.end_token
chosen_token = tokenizer(chosen,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.chosen.append({
"input_ids": chosen_token['input_ids'],
"attention_mask": chosen_token['attention_mask']
})
chosen = prompt + data["chosen"] + self.end_token
chosen_token = tokenizer(
chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.chosen.append(
{"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
)
reject = prompt + data['rejected'] + self.end_token
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject.append({
"input_ids": reject_token['input_ids'],
"attention_mask": reject_token['attention_mask']
})
reject = prompt + data["rejected"] + self.end_token
reject_token = tokenizer(
reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)
self.reject.append(
{"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
)
def __len__(self):
length = len(self.chosen)
return length
def __getitem__(self, idx):
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
"input_ids"], self.reject[idx]["attention_mask"]
return (
self.chosen[idx]["input_ids"],
self.chosen[idx]["attention_mask"],
self.reject[idx]["input_ids"],
self.reject[idx]["attention_mask"],
)
#python representation of the object and the string representation of the object
# python representation of the object and the string representation of the object
def __repr__(self):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
@ -165,26 +163,25 @@ class EasyRewardDataset(Dataset):
return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})"
'''
"""
Easy SFT just accepts a text file which can be read line by line. However, the dataset will group texts together up to max_length so the LLM will learn the texts' meaning better.
If individual lines are not related, just set is_group_texts to False.
'''
"""
class EasySFTDataset(Dataset):
def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None:
super().__init__()
#read the data_file line by line
# read the data_file line by line
with open(data_file, "r", encoding="UTF-8") as f:
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
# encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids = []
for line in f:
encoded_ids = tokenizer.encode(line)
#if the encoded_ids is longer than max_length, then split it into several parts
# if the encoded_ids is longer than max_length, then split it into several parts
if len(encoded_ids) > max_length:
for i in range(0, len(encoded_ids), max_length):
raw_input_ids.append(encoded_ids[i:i + max_length])
raw_input_ids.append(encoded_ids[i : i + max_length])
else:
raw_input_ids.append(encoded_ids)
@ -196,12 +193,13 @@ class EasySFTDataset(Dataset):
if is_group_texts:
for input_ids in raw_input_ids:
if len(current_input_ids) + len(input_ids) > max_length:
#pad the current_input_ids to max_length with tokenizer.pad_token_id
# pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length = max_length - len(current_input_ids)
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
current_input_ids = []
else:
current_input_ids.extend(input_ids)
@ -210,14 +208,16 @@ class EasySFTDataset(Dataset):
current_input_ids.extend([tokenizer.pad_token_id] * padded_length)
grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long))
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
else:
#just append the raw_input_ids to max_length
# just append the raw_input_ids to max_length
for input_ids in raw_input_ids:
padded_length = max_length - len(input_ids)
input_ids.extend([tokenizer.pad_token_id] * padded_length)
attention_mask.append(
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long))
torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)
)
grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
self.input_ids = grouped_input_ids
self.labels = copy.deepcopy(self.input_ids)
@ -227,14 +227,14 @@ class EasySFTDataset(Dataset):
def __len__(self):
return len(self.input_ids)
#get item from dataset
# get item from dataset
def __getitem__(self, idx):
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
#generate the dataset description to be printed by print in python
# generate the dataset description to be printed by print in python
def __repr__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"
#generate the dataset description to be printed by print in python
# generate the dataset description to be printed by print in python
def __str__(self):
return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})"

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from coati.models.generation import generate
from coati.models.utils import log_probs_from_logits, masked_mean
from coati.models.utils import log_probs_from_logits
from peft import PeftModel
from torch.nn.modules import Module
from transformers import BloomConfig, BloomForCausalLM
@ -24,38 +24,33 @@ class Actor(Module):
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
self, input_ids: torch.Tensor, return_action_mask: bool = True, **kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
sequences = generate(self.model, input_ids, **kwargs)
attention_mask = None
pad_token_id = kwargs.get('pad_token_id', None)
pad_token_id = kwargs.get("pad_token_id", None)
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
if not return_action_mask:
return sequences, attention_mask, None
input_len = input_ids.size(1)
eos_token_id = kwargs.get('eos_token_id', None)
eos_token_id = kwargs.get("eos_token_id", None)
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
# left padding may be applied, only mask action
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
action_mask[:, :input_len] = False
action_mask = action_mask[:, 1:]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len) :]
def forward(self,
sequences: torch.LongTensor,
num_actions: int,
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Returns action log probs
"""
def forward(
self, sequences: torch.LongTensor, num_actions: int, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Returns action log probs"""
output = self.model(sequences, attention_mask=attention_mask)
logits = output['logits']
logits = output["logits"]
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
@ -75,11 +70,13 @@ class BLOOMActor(Actor):
lora_train_bias (str): LoRA bias training mode.
"""
def __init__(self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_path: str = None) -> None:
def __init__(
self,
pretrained: str = None,
config: Optional[BloomConfig] = None,
checkpoint: bool = False,
lora_path: str = None,
) -> None:
if pretrained is not None:
model = BloomForCausalLM.from_pretrained(pretrained)
elif config is not None:
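A rough sketch of how the Actor wrapper above is used during rollout: generate sequences plus an action mask, then query per-action log probs. The checkpoint name, prompt ids and generation kwargs are assumptions:

import torch

from easy_models import BLOOMActor  # the class defined above

actor = BLOOMActor(pretrained="bigscience/bloom-560m")   # checkpoint name is an assumption
input_ids = torch.randint(0, 1000, (2, 16))              # toy prompt ids, batch of 2
sequences, attention_mask, action_mask = actor.generate(
    input_ids, return_action_mask=True, max_length=32, pad_token_id=3, eos_token_id=2
)
num_actions = action_mask.size(1)
log_probs = actor(sequences, num_actions, attention_mask)  # log probs of the generated actions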

@ -1,18 +1,16 @@
import argparse
import pandas as pd
import torch
import torch.distributed as dist
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
from coati.dataset import DataCollatorForSupervisedDataset
from coati.models.bloom import BLOOMRM, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.gpt import GPTRM, GPTCritic
from coati.models.llama import LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTCritic
from coati.trainer import PPOTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from easy_dataset import EasyPromptsDataset, EasySupervisedDataset
from easy_models import BLOOMActor
from peft import PeftModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@ -23,24 +21,24 @@ from colossalai.nn.optimizer import HybridAdam
def main(args):
# configure strategy
if args.strategy == 'ddp':
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu')
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
if args.rm_path is not None:
state_dict = torch.load(args.rm_path, map_location='cpu')
state_dict = torch.load(args.rm_path, map_location="cpu")
# configure model
if args.model == 'bloom':
if args.model == "bloom":
# initial_model = BLOOMActor(pretrained=args.pretrain)
print('Using peft lora to load Bloom model as initial_model')
print("Using peft lora to load Bloom model as initial_model")
initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as initial_model (Done)')
print("Using peft lora to load Bloom model as initial_model (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
@ -49,59 +47,59 @@ def main(args):
else:
rm_model_name = args.rm_model
if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
print("load bloom reward model ", args.rm_pretrain)
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
reward_model = LlamaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
print('Loading reward model from', args.rm_path)
print("Loading reward model from", args.rm_path)
reward_model.load_state_dict(state_dict)
if args.strategy != 'colossalai_gemini':
if args.strategy != "colossalai_gemini":
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
if args.model == 'bloom':
if args.model == "bloom":
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print('Using peft lora to load Bloom model as Actor')
print("Using peft lora to load Bloom model as Actor")
actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path)
print('Using peft lora to load Bloom model as Actor (Done)')
print("Using peft lora to load Bloom model as Actor (Done)")
else:
raise ValueError(f'Unsupported actor model "{args.model}"')
if rm_model_name == 'gpt2':
if rm_model_name == "gpt2":
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'bloom':
elif rm_model_name == "bloom":
print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True)
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
print("load bloom critic (Done) ")
elif rm_model_name == 'opt':
elif rm_model_name == "opt":
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
elif rm_model_name == 'llama':
elif rm_model_name == "llama":
critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None:
print('Loading reward model from', args.rm_path)
print("Loading reward model from", args.rm_path)
critic.load_state_dict(state_dict)
del state_dict
if args.strategy != 'colossalai_gemini':
if args.strategy != "colossalai_gemini":
critic.to(torch.float16).to(torch.cuda.current_device())
actor.to(torch.float16).to(torch.cuda.current_device())
# configure optimizer
if args.strategy.startswith('colossalai'):
if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
else:
@ -109,18 +107,18 @@ def main(args):
critic_optim = Adam(critic.parameters(), lr=1e-7)
# configure tokenizer
if args.model == 'gpt2':
if args.model == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'bloom':
elif args.model == "bloom":
tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'opt':
elif args.model == "opt":
tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain)
tokenizer.pad_token = tokenizer.eos_token
elif args.model == 'llama':
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
tokenizer.eos_token = '<\s>'
tokenizer.eos_token = "<\s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
@ -132,26 +130,27 @@ def main(args):
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
else:
prompt_sampler = None
prompt_dataloader = DataLoader(prompt_dataset,
shuffle=(prompt_sampler is None),
sampler=prompt_sampler,
batch_size=args.train_batch_size)
prompt_dataloader = DataLoader(
prompt_dataset, shuffle=(prompt_sampler is None), sampler=prompt_sampler, batch_size=args.train_batch_size
)
pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
else:
pretrain_sampler = None
pretrain_dataloader = DataLoader(pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size,
collate_fn=data_collator)
pretrain_dataloader = DataLoader(
pretrain_dataset,
shuffle=(pretrain_sampler is None),
sampler=pretrain_sampler,
batch_size=args.ptx_batch_size,
collate_fn=data_collator,
)
def tokenize_fn(texts):
# MUST pad to max length to ensure inputs of all ranks have the same length
# Different lengths may lead to a hang when using gemini, as the number of generation steps can differ
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
batch = tokenizer(texts, return_tensors="pt", max_length=96, padding="max_length", truncation=True)
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
@ -178,45 +177,46 @@ def main(args):
eos_token_id=tokenizer.eos_token_id,
)
trainer.fit(prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps)
trainer.fit(
prompt_dataloader=prompt_dataloader,
pretrain_dataloader=pretrain_dataloader,
num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps,
)
# save model checkpoint after fitting
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(actor_optim,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
strategy.save_optimizer(
actor_optim, "actor_optim_checkpoint_prompts_%d.pt" % (torch.cuda.current_device()), only_rank0=False
)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
parser.add_argument('--strategy',
choices=['ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='ddp',
help='strategy to use')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--sft_lora_path', type=str, default=None)
parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama'])
parser.add_argument('--rm_path', type=str, default=None)
parser.add_argument('--rm_pretrain', type=str, default=None)
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--num_collect_steps', type=int, default=10)
parser.add_argument('--num_update_steps', type=int, default=5)
parser.add_argument('--train_batch_size', type=int, default=2)
parser.add_argument('--ptx_batch_size', type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--kl_coef', type=float, default=0.1)
parser.add_argument('--ptx_coef', type=float, default=0.9)
parser.add_argument("--prompt_path", type=str, default=None, help="path to the prompt dataset")
parser.add_argument("--pretrain_dataset", type=str, default=None, help="path to the pretrained dataset")
parser.add_argument(
"--strategy", choices=["ddp", "colossalai_gemini", "colossalai_zero2"], default="ddp", help="strategy to use"
)
parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--sft_lora_path", type=str, default=None)
parser.add_argument("--rm_model", default=None, choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts")
parser.add_argument("--need_optim_ckpt", type=bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--num_collect_steps", type=int, default=10)
parser.add_argument("--num_update_steps", type=int, default=5)
parser.add_argument("--train_batch_size", type=int, default=2)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
args = parser.parse_args()
main(args)

Some files were not shown because too many files have changed in this diff.