mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -10,7 +10,7 @@ class ValPosition:
|
||||
offset: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = f'[partition_id:{self.partition_id},offset:{self.offset}]'
|
||||
res = f"[partition_id:{self.partition_id},offset:{self.offset}]"
|
||||
return res
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -18,7 +18,6 @@ class ValPosition:
|
||||
|
||||
|
||||
class PartitionInputVal(object):
|
||||
|
||||
def __init__(self, partition_id, offset) -> None:
|
||||
# every input from which partition_id and which offset
|
||||
val_pos = ValPosition(partition_id, offset)
|
||||
@@ -28,8 +27,8 @@ class PartitionInputVal(object):
|
||||
return self._from_partition_and_offset
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res += f'<-({self._from_partition_and_offset})'
|
||||
res = ""
|
||||
res += f"<-({self._from_partition_and_offset})"
|
||||
return res
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -37,7 +36,6 @@ class PartitionInputVal(object):
|
||||
|
||||
|
||||
class PartitionOutputVal(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
# every output to which partition_id and which offset
|
||||
self._to_partition_and_offset: List[ValPosition] = []
|
||||
@@ -50,11 +48,11 @@ class PartitionOutputVal(object):
|
||||
return self._to_partition_and_offset
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res += '->('
|
||||
res = ""
|
||||
res += "->("
|
||||
for val_pos in self._to_partition_and_offset:
|
||||
res += f'{val_pos},'
|
||||
res += ')'
|
||||
res += f"{val_pos},"
|
||||
res += ")"
|
||||
return res
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -62,7 +60,6 @@ class PartitionOutputVal(object):
|
||||
|
||||
|
||||
class Partition(object):
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._input_vals: List[PartitionInputVal] = []
|
||||
self._output_vals: List[PartitionOutputVal] = []
|
||||
@@ -110,16 +107,16 @@ class Partition(object):
|
||||
return res
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res += f' input:\n'
|
||||
res += f' length:{len(self._input_vals)}\n'
|
||||
res = ""
|
||||
res += f" input:\n"
|
||||
res += f" length:{len(self._input_vals)}\n"
|
||||
for i, input_val in enumerate(self._input_vals):
|
||||
res += f' offset={i}:{input_val}\n'
|
||||
res += f" offset={i}:{input_val}\n"
|
||||
|
||||
res += f' output:\n'
|
||||
res += f' length:{len(self._output_vals)}\n'
|
||||
res += f" output:\n"
|
||||
res += f" length:{len(self._output_vals)}\n"
|
||||
for i, output_val in enumerate(self._output_vals):
|
||||
res += f' offset={i}:{output_val}\n'
|
||||
res += f" offset={i}:{output_val}\n"
|
||||
|
||||
return res
|
||||
|
||||
@@ -140,7 +137,6 @@ class Partition(object):
|
||||
# _input_partition_id: the key represents input_partition
|
||||
# _output_partition_id: the key represents output_partition
|
||||
class Topo(object):
|
||||
|
||||
def __init__(self, input_partition_id=None, output_partition_id=None) -> None:
|
||||
self._partitions: Dict[int, Partition] = {}
|
||||
self._input_partition_id = input_partition_id
|
||||
@@ -162,7 +158,7 @@ class Topo(object):
|
||||
self._partitions[partition_id] = partition
|
||||
|
||||
def get_mid_partitions(self):
|
||||
res = {} #{partition_id: Partition}
|
||||
res = {} # {partition_id: Partition}
|
||||
for partition_id, partition in self._partitions.items():
|
||||
if self._input_partition_id == partition_id or self._output_partition_id == partition_id:
|
||||
continue
|
||||
@@ -186,27 +182,27 @@ class Topo(object):
|
||||
return self._partitions[partition_id]
|
||||
|
||||
def __str__(self) -> str:
|
||||
res = ''
|
||||
res = ""
|
||||
if len(self._partitions) == 0:
|
||||
return 'Empty Topo Graph.'
|
||||
return "Empty Topo Graph."
|
||||
|
||||
input_part = self.get_input_partition()
|
||||
if input_part is not None:
|
||||
res += '{\n'
|
||||
res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}'
|
||||
res += '}\n'
|
||||
res += "{\n"
|
||||
res += f"InputPartition:\n partition_id={self._input_partition_id}\n{input_part}"
|
||||
res += "}\n"
|
||||
|
||||
mid_parts = self.get_mid_partitions()
|
||||
for i, (partition_id, part) in enumerate(mid_parts.items()):
|
||||
res += '{\n'
|
||||
res += f'SubPartition_{i}:\n partition_id={partition_id}\n {part}'
|
||||
res += '}\n'
|
||||
res += "{\n"
|
||||
res += f"SubPartition_{i}:\n partition_id={partition_id}\n {part}"
|
||||
res += "}\n"
|
||||
|
||||
output_part = self.get_output_partition()
|
||||
if output_part is not None:
|
||||
res += '{\n'
|
||||
res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}'
|
||||
res += '}\n'
|
||||
res += "{\n"
|
||||
res += f"OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}"
|
||||
res += "}\n"
|
||||
|
||||
return res
|
||||
|
||||
|
Reference in New Issue
Block a user