mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -3,13 +3,11 @@
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
|
||||
from ldm.modules.midas.midas.midas_net import MidasNet
|
||||
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
|
||||
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
|
||||
|
||||
from ldm.modules.midas.midas.transforms import NormalizeImage, PrepareForNet, Resize
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
ISL_PATHS = {
|
||||
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
|
||||
@@ -98,18 +96,20 @@ def load_model(model_type):
|
||||
model = MidasNet(model_path, non_negative=True)
|
||||
net_w, net_h = 384, 384
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
elif model_type == "midas_v21_small":
|
||||
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
||||
non_negative=True, blocks={'expand': True})
|
||||
model = MidasNet_small(
|
||||
model_path,
|
||||
features=64,
|
||||
backbone="efficientnet_lite3",
|
||||
exportable=True,
|
||||
non_negative=True,
|
||||
blocks={"expand": True},
|
||||
)
|
||||
net_w, net_h = 256, 256
|
||||
resize_mode = "upper_bound"
|
||||
normalization = NormalizeImage(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
else:
|
||||
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
||||
@@ -135,11 +135,7 @@ def load_model(model_type):
|
||||
|
||||
|
||||
class MiDaSInference(nn.Module):
|
||||
MODEL_TYPES_TORCH_HUB = [
|
||||
"DPT_Large",
|
||||
"DPT_Hybrid",
|
||||
"MiDaS_small"
|
||||
]
|
||||
MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
|
||||
MODEL_TYPES_ISL = [
|
||||
"dpt_large",
|
||||
"dpt_hybrid",
|
||||
@@ -149,7 +145,7 @@ class MiDaSInference(nn.Module):
|
||||
|
||||
def __init__(self, model_type):
|
||||
super().__init__()
|
||||
assert (model_type in self.MODEL_TYPES_ISL)
|
||||
assert model_type in self.MODEL_TYPES_ISL
|
||||
model, _ = load_model(model_type)
|
||||
self.model = model
|
||||
self.model.train = disabled_train
|
||||
@@ -167,4 +163,3 @@ class MiDaSInference(nn.Module):
|
||||
)
|
||||
assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
|
||||
return prediction
|
||||
|
||||
|
Reference in New Issue
Block a user