from fastai.vision.all import *
from fastai.vision.learner import _default_meta
from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig


class DynamicUnetDIY(SequentialEx):
    "Create a U-Net from a given architecture."

    def __init__(
        self,
        arch=resnet34,
        n_classes=6,
        img_size=(600, 400),
        blur=False,
        blur_final=True,
        y_range=None,
        last_cross=True,
        bottle=False,
        init=nn.init.kaiming_normal_,
        norm_type=None,
        self_attention=True,
        act_cls=Mish,
        n_in=3,
        cut=None,
        **kwargs
    ):
        # Look up the architecture's metadata (cut point, etc.) and build the
        # encoder body without pretrained weights.
        meta = model_meta.get(arch, _default_meta)
        encoder = create_body(arch, n_in, False, ifnone(cut, meta["cut"]))
        imsize = img_size

        # Record activation sizes through the encoder, hook the layers where the
        # feature-map size changes (these feed the skip connections), and run a
        # dummy batch to infer shapes for the decoder.
        sizes = model_sizes(encoder, size=imsize)
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        # Bottleneck: two convolutions that double, then restore, the channel count.
        ni = sizes[-1][1]
        middle_conv = nn.Sequential(
            ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type, **kwargs),
            ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        # Decoder: one UnetBlock per encoder stage, each upsampling and merging
        # with the hooked skip connection; self-attention goes on the
        # third-to-last block.
        for i, idx in enumerate(sz_chg_idxs):
            not_final = i != len(sz_chg_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sz_chg_idxs) - 3)
            unet_block = UnetBlock(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                act_cls=act_cls,
                init=init,
                norm_type=norm_type,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        # Head: upsample back to the input resolution, optionally merge with the
        # network input (last_cross), and map to one channel per class.
        ni = x.shape[1]
        if imsize != sizes[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(
                ResBlock(
                    1,
                    ni,
                    ni // 2 if bottle else ni,
                    act_cls=act_cls,
                    norm_type=norm_type,
                    **kwargs
                )
            )
        layers += [
            ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)
        ]
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        # apply_init(nn.Sequential(layers[2]), init)
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        # Remove the forward hooks when the model is garbage-collected.
        if hasattr(self, "sfs"):
            self.sfs.remove()
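

# A minimal smoke test, not part of the original commit, assuming the default
# constructor arguments above: build the model and push a dummy batch through
# it to check that the output carries one channel per class at the input
# resolution. The batch size and expected shape below are illustrative.
if __name__ == "__main__":
    model = DynamicUnetDIY(arch=resnet34, n_classes=6, img_size=(600, 400))
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 600, 400))
    print(out.shape)  # expected: torch.Size([1, 6, 600, 400])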