Skip to content

Commit 82b69a8

Browse files
author
Baichuan Sun
committed
add: model and torch inference
1 parent 71d1c52 commit 82b69a8

File tree

3 files changed

+1551
-0
lines changed

3 files changed

+1551
-0
lines changed

deployment/model.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from fastai.vision.all import *
2+
from fastai.vision.learner import _default_meta
3+
from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig
4+
5+
6+
class DynamicUnetDIY(SequentialEx):
7+
"Create a U-Net from a given architecture."
8+
9+
def __init__(
10+
self,
11+
arch=resnet34,
12+
n_classes=6,
13+
img_size=(600, 400),
14+
blur=False,
15+
blur_final=True,
16+
y_range=None,
17+
last_cross=True,
18+
bottle=False,
19+
init=nn.init.kaiming_normal_,
20+
norm_type=None,
21+
self_attention=True,
22+
act_cls=Mish,
23+
n_in=3,
24+
cut=None,
25+
**kwargs
26+
):
27+
meta = model_meta.get(arch, _default_meta)
28+
encoder = create_body(arch, n_in, False, ifnone(cut, meta["cut"]))
29+
imsize = img_size
30+
31+
sizes = model_sizes(encoder, size=imsize)
32+
sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
33+
self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
34+
x = dummy_eval(encoder, imsize).detach()
35+
36+
ni = sizes[-1][1]
37+
middle_conv = nn.Sequential(
38+
ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type, **kwargs),
39+
ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),
40+
).eval()
41+
x = middle_conv(x)
42+
layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]
43+
44+
for i, idx in enumerate(sz_chg_idxs):
45+
not_final = i != len(sz_chg_idxs) - 1
46+
up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
47+
do_blur = blur and (not_final or blur_final)
48+
sa = self_attention and (i == len(sz_chg_idxs) - 3)
49+
unet_block = UnetBlock(
50+
up_in_c,
51+
x_in_c,
52+
self.sfs[i],
53+
final_div=not_final,
54+
blur=do_blur,
55+
self_attention=sa,
56+
act_cls=act_cls,
57+
init=init,
58+
norm_type=norm_type,
59+
**kwargs
60+
).eval()
61+
layers.append(unet_block)
62+
x = unet_block(x)
63+
64+
ni = x.shape[1]
65+
if imsize != sizes[0][-2:]:
66+
layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
67+
layers.append(ResizeToOrig())
68+
if last_cross:
69+
layers.append(MergeLayer(dense=True))
70+
ni += in_channels(encoder)
71+
layers.append(
72+
ResBlock(
73+
1,
74+
ni,
75+
ni // 2 if bottle else ni,
76+
act_cls=act_cls,
77+
norm_type=norm_type,
78+
**kwargs
79+
)
80+
)
81+
layers += [
82+
ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)
83+
]
84+
apply_init(nn.Sequential(layers[3], layers[-2]), init)
85+
# apply_init(nn.Sequential(layers[2]), init)
86+
if y_range is not None:
87+
layers.append(SigmoidRange(*y_range))
88+
super().__init__(*layers)
89+
90+
def __del__(self):
91+
if hasattr(self, "sfs"):
92+
self.sfs.remove()

0 commit comments

Comments
 (0)