@@ -434,6 +434,7 @@ def __init__(self, upsampling_factor):
         self.pool = torch.nn.AvgPool2d(kernel_size=upsampling_factor)
         self.upsampling_factor = upsampling_factor
     def forward(self, y, lr):
+        y = y.clone()
         out = self.pool(y)
         out = y * torch.kron(lr * 1 / out, torch.ones((self.upsampling_factor, self.upsampling_factor)).to('cuda'))
         return out
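The new `y = y.clone()` line makes the constraint operate on a copy of the prediction, so the caller's tensor is never mutated when the pooled ratio is applied. A minimal, self-contained sketch of this multiplicative downscaling constraint (CPU-only; function name and the positive fake data are illustrative, the class above additionally hard-codes a 'cuda' device):

import torch
import torch.nn.functional as F

def mult_downscale(y, lr, factor):
    # Rescale each factor-by-factor patch of the high-res prediction y so
    # that its mean matches the corresponding low-resolution pixel in lr.
    y = y.clone()  # work on a copy, as in the fix above
    pooled = F.avg_pool2d(y, kernel_size=factor)
    # torch.kron broadcasts the per-patch ratio back to full resolution.
    return y * torch.kron(lr / pooled, torch.ones(factor, factor))

hr = torch.rand(1, 1, 8, 8) + 0.5   # fake high-res prediction (kept positive)
lr = torch.rand(1, 1, 2, 2) + 0.5   # fake low-res reference field
out = mult_downscale(hr, lr, 4)
print(torch.allclose(F.avg_pool2d(out, 4), lr))  # True: patch means now match lr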
@@ -700,10 +701,84 @@ def forward(self, x, mr=None, z=None):
             # out[:,0,i,...] = self.constraints(out, x[:,0,i,...])
             #out[:,0,:,:] *= 16
             out = out.unsqueeze(1)
+            return out
+
+class ResNet3(nn.Module):
+    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=2):
+        super(ResNet3, self).__init__()
+        # First layer
+        if noise:
+            self.conv_trans0 = nn.ConvTranspose2d(100, 1, kernel_size=(32,32), padding=0, stride=1)
+            self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
+        else:
+            self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
+        # Residual blocks
+        self.res_blocks = nn.ModuleList()
+        for k in range(number_residual_blocks):
+            self.res_blocks.append(ResidualBlock(number_channels, number_channels))
+        # Second conv layer post residual blocks
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
+        # Upsampling layers
+        self.upsampling = nn.ModuleList()
+        for k in range(1):
+            self.upsampling.append(nn.ConvTranspose2d(number_channels, number_channels, kernel_size=3, padding=0, stride=3))
+        # Next layer after upsampling
+        self.conv3 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
+        # Final output layer
+        self.conv4 = nn.Conv2d(number_channels, dim, kernel_size=1, stride=1, padding=0)
+        # optional renormalization layer
+        self.is_constraints = False
+        if constraints == 'softmax':
+            self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor, cwindow_size=cwindow_size)
+            self.is_constraints = True
+        elif constraints == 'enforce_op':
+            self.constraints = EnforcementOperator(upsampling_factor=upsampling_factor)
+            self.is_constraints = True
+        elif constraints == 'add':
+            self.constraints = AddDownscaleConstraints(upsampling_factor=upsampling_factor)
+            self.is_constraints = True
+        elif constraints == 'mult':
+            self.constraints = MultDownscaleConstraints(upsampling_factor=upsampling_factor)
+            self.is_constraints = True
+
+        self.dim = dim
+        self.noise = noise
+
+    def forward(self, x, mr=None, z=None):
+        if self.noise:
+            out = self.conv_trans0(z)
+            out = self.conv1(torch.cat((x[:,0,...], out), dim=1))
+            for layer in self.res_blocks:
+                out = layer(out)
+            out = self.conv2(out)
+            for layer in self.upsampling:
+                out = layer(out)
+            out = self.conv3(out)
+            out = self.conv4(out)
+            if self.is_constraints:
+                out = self.constraints(out, x[:,0,...])
             return out
+        else:
+            #print(x.shape)
+            out = self.conv1(x[:,0,...])
+            for layer in self.upsampling:
+                out = layer(out)
+            out = self.conv2(out)
+            for layer in self.res_blocks:
+                out = layer(out)
+            out = self.conv3(out)
+            out = self.conv4(out)
+            if self.is_constraints:
+                out[:,...] = self.constraints(out, x[:,0,...])
+            #for i in range(self.dim):
+            #    out[:,0,i,...] = self.constraints(out, x[:,0,i,...])
+            #out[:,0,:,:] *= 16
+            out = out.unsqueeze(1)
+            return out
 
 class ResNetNoise(nn.Module):
-    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1):
+    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, cwindow_size=2):
         super(ResNetNoise, self).__init__()
         # First layer
 
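One detail worth noting in the new ResNet3: its upsampling stage is a single ConvTranspose2d with kernel_size=3, padding=0, stride=3, which triples the spatial resolution regardless of the upsampling_factor argument (that argument only reaches the constraint layer). A standalone shape check illustrating this:

import torch
import torch.nn as nn

# ConvTranspose2d output size: (H - 1) * stride - 2 * padding + kernel_size
# = (16 - 1) * 3 - 0 + 3 = 48, i.e. a fixed 3x upsample.
up = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=0, stride=3)
x = torch.rand(2, 64, 16, 16)
print(up(x).shape)  # torch.Size([2, 64, 48, 48])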
@@ -728,7 +803,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
         # optional renormalization layer
         self.is_constraints = False
         if constraints == 'softmax':
-            self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor)
+            self.constraints = SoftmaxConstraints(upsampling_factor=upsampling_factor, cwindow_size=cwindow_size)
             self.is_constraints = True
         elif constraints == 'enforce_op':
             self.constraints = EnforcementOperator(upsampling_factor=upsampling_factor)
@@ -817,7 +892,7 @@ def forward(self, x):
 
 
 class ResNet2Up(nn.Module):
-    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, output_mr=False):
+    def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, noise=False, constraints='none', dim=1, output_mr=False, cwindow_size=2):
         super(ResNet2Up, self).__init__()
         # PART I
         self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
@@ -833,7 +908,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
 
         self.is_constraints = False
         if constraints == 'softmax':
-            self.constraints = SoftmaxConstraints(upsampling_factor=2)
+            self.constraints = SoftmaxConstraints(upsampling_factor=2, cwindow_size=cwindow_size)
             self.is_constraints = True
         elif constraints == 'enforce_op':
             self.constraints = EnforcementOperator(upsampling_factor=2)
@@ -844,7 +919,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
 
         self.is_single_constraints = False
         if constraints == 'softmax_single':
-            self.constraints = SoftmaxConstraints(upsampling_factor=4)
+            self.constraints = SoftmaxConstraints(upsampling_factor=4, cwindow_size=cwindow_size)
             self.is_single_constraints = True
         elif constraints == 'enforce_op_single':
             self.constraints = EnforcementOperator(upsampling_factor=4)
@@ -865,7 +940,7 @@ def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_fact
         self.conv23 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))
         self.conv24 = nn.Conv2d(number_channels, dim, kernel_size=1, stride=1, padding=0)
         if constraints == 'softmax':
-            self.constraints2 = SoftmaxConstraints(upsampling_factor=2)
+            self.constraints2 = SoftmaxConstraints(upsampling_factor=2, cwindow_size=cwindow_size)
         elif constraints == 'enforce_op':
             self.constraints2 = EnforcementOperator(upsampling_factor=2)
         elif constraints == 'mult':
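The remaining hunks all make the same change: every SoftmaxConstraints call now receives the new cwindow_size argument, which decouples the constraint window from the upsampling factor. The SoftmaxConstraints implementation itself is not part of this diff; purely as an illustration of the idea, a softmax-based downscaling constraint over a cwindow_size window might look like the following hypothetical standalone function:

import torch
import torch.nn.functional as F

def softmax_constraint(y, lr, window):
    # Sketch only, not the repository's SoftmaxConstraints class:
    # redistribute each low-res value over its window x window patch with
    # softmax weights, so the patch mean of the output equals lr.
    w = torch.exp(y)
    patch_sum = F.avg_pool2d(w, window) * window**2
    return w * torch.kron(lr * window**2 / patch_sum, torch.ones(window, window))

y = torch.rand(1, 1, 8, 8)    # unconstrained network output
lr = torch.rand(1, 1, 4, 4)   # low-res field to conserve
out = softmax_constraint(y, lr, 2)
print(torch.allclose(F.avg_pool2d(out, 2), lr))  # True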