adapt new fx

This commit is contained in:
oahzxl
2023-01-10 11:56:00 +08:00
parent e532679c95
commit 7ab2db206f
4 changed files with 12 additions and 14 deletions

View File

@@ -585,9 +585,9 @@ if CODEGEN_AVAILABLE:
code = "".join(body)
code = "\n".join(" " + line for line in code.split("\n"))
fn_code = f"""
{wrap_stmts}
{wrap_stmts}
{prologue}
{code}"""
{prologue}
{code}"""
# print(fn_code)
return PythonCode(fn_code, globals_)

View File

@@ -28,12 +28,7 @@ class EstimateMemory(object):
return x
def _get_output_node(self, n):
fwd_out = {
x.uuid: x
for x in n.meta["fwd_out"]
if isinstance(x, torch.Tensor) and hasattr(x, "uuid")
}
out_size = activation_size(fwd_out)
out_size = activation_size(n.meta["fwd_out"])
out_node = [n.name] if out_size > 0 else []
return out_size, out_node