mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 18:13:56 +00:00 
			
		
		
		
	* support n_behind, add profiling * fix bugs * fix visualization * fix behind * fix loop issue * add profiling * fix update * update assert * remove assert --------- Co-authored-by: Tong Li <tong.li35271158@gmail.com>
		
			
				
	
	
		
			101 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Re-import required libraries due to kernel reset
 | |
| import argparse
 | |
| from collections import defaultdict
 | |
| 
 | |
| import matplotlib.cm as cm
 | |
| import matplotlib.pyplot as plt
 | |
| 
 | |
| # Argument parser for command line arguments
 | |
| parser = argparse.ArgumentParser(description="Process profiling logs and generate a timeline plot.")
 | |
| parser.add_argument("--visualization", type=str, default="actor_timelines.png", help="Path to the visualization file.")
 | |
| args = parser.parse_args()
 | |
| 
 | |
| # Raw log lines
 | |
| log_lines = []
 | |
| 
 | |
| import glob
 | |
| 
 | |
| files = glob.glob("*.prof")
 | |
| for file in files:
 | |
|     with open(file, "r") as f:
 | |
|         log_lines += f.readlines()
 | |
| 
 | |
| # Parse logs and collect function intervals grouped by actor
 | |
| actors = defaultdict(lambda: defaultdict(list))
 | |
| current_entries = {}
 | |
| 
 | |
| # First, collect all timestamps to find the minimum
 | |
| all_timestamps = []
 | |
| parsed_lines = []
 | |
| 
 | |
| for line in log_lines:
 | |
|     if line.startswith("[Log]"):
 | |
|         continue
 | |
|     parts = line.split()
 | |
|     timestamp = float(parts[0])
 | |
|     actor = parts[1]
 | |
|     action = parts[3]
 | |
|     func_name = parts[4]
 | |
|     parsed_lines.append((timestamp, actor, action, func_name))
 | |
|     all_timestamps.append(timestamp)
 | |
| 
 | |
| if not all_timestamps:
 | |
|     raise ValueError("No valid log entries found.")
 | |
| 
 | |
| min_timestamp = min(all_timestamps)
 | |
| 
 | |
| for timestamp, actor, action, func_name in parsed_lines:
 | |
|     rel_timestamp = timestamp - min_timestamp
 | |
|     key = (actor, func_name)
 | |
|     if action == "Enter":
 | |
|         current_entries[key] = rel_timestamp
 | |
|     elif action == "Exit":
 | |
|         start_time = current_entries.pop(key, None)
 | |
|         if start_time is not None:
 | |
|             actors[actor][func_name].append((start_time, rel_timestamp))
 | |
| 
 | |
| # Plotting setup
 | |
| fig, ax = plt.subplots(figsize=(12, 6))
 | |
| colors = cm.get_cmap("tab10", len(actors))
 | |
| 
 | |
| actor_offsets = {}
 | |
| base_offset = 0
 | |
| function_spacing = 0.9
 | |
| 
 | |
| yticks = []
 | |
| yticklabels = []
 | |
| 
 | |
| for idx, (actor, func_dict) in enumerate(actors.items()):
 | |
|     actor_offsets[actor] = base_offset
 | |
|     color = colors(idx)
 | |
|     for j, (func, intervals) in enumerate(func_dict.items()):
 | |
|         print(actor, func, intervals)
 | |
|         y_val = base_offset + j * function_spacing
 | |
|         yticks.append(y_val)
 | |
|         yticklabels.append(f"{actor}:{func}")
 | |
|         for start, end in intervals:
 | |
|             if end - start < 1:
 | |
|                 end = start + 1  # Ensure all lines are at least 3 units long
 | |
|             ax.plot(
 | |
|                 [start, end],
 | |
|                 [y_val, y_val],
 | |
|                 color=color,
 | |
|                 linewidth=2,
 | |
|                 label=actor if j == 0 else "",
 | |
|             )
 | |
|     base_offset += len(func_dict) * function_spacing + 1
 | |
| 
 | |
| # Formatting
 | |
| ax.set_yticks(yticks)
 | |
| ax.set_yticklabels(yticklabels)
 | |
| ax.set_xlabel("Time")
 | |
| ax.set_title("Timeline per Actor")
 | |
| # Remove duplicate labels in legend
 | |
| handles, labels = ax.get_legend_handles_labels()
 | |
| unique = dict(zip(labels, handles))
 | |
| ax.legend(unique.values(), unique.keys())
 | |
| plt.tight_layout()
 | |
| plt.grid(True)
 | |
| plt.savefig(args.visualization, dpi=600)  # Increase dpi for higher resolution
 | |
| print(f"Plot saved as {args.visualization}")
 |