22import torch .nn as nn
33import torch .nn .functional as F
44
5+
56class BasicBlock (nn .Module ):
6- def __init__ (self , c_in , c_out ,is_downsample = False ):
7- super (BasicBlock ,self ).__init__ ()
7+ def __init__ (self , c_in , c_out , is_downsample = False ):
8+ super (BasicBlock , self ).__init__ ()
89 self .is_downsample = is_downsample
910 if is_downsample :
1011 self .conv1 = nn .Conv2d (c_in , c_out , 3 , stride = 2 , padding = 1 , bias = False )
1112 else :
1213 self .conv1 = nn .Conv2d (c_in , c_out , 3 , stride = 1 , padding = 1 , bias = False )
1314 self .bn1 = nn .BatchNorm2d (c_out )
1415 self .relu = nn .ReLU (True )
15- self .conv2 = nn .Conv2d (c_out ,c_out ,3 , stride = 1 ,padding = 1 , bias = False )
16+ self .conv2 = nn .Conv2d (c_out , c_out , 3 , stride = 1 , padding = 1 , bias = False )
1617 self .bn2 = nn .BatchNorm2d (c_out )
1718 if is_downsample :
1819 self .downsample = nn .Sequential (
@@ -26,48 +27,50 @@ def __init__(self, c_in, c_out,is_downsample=False):
2627 )
2728 self .is_downsample = True
2829
29- def forward (self ,x ):
30+ def forward (self , x ):
3031 y = self .conv1 (x )
3132 y = self .bn1 (y )
3233 y = self .relu (y )
3334 y = self .conv2 (y )
3435 y = self .bn2 (y )
3536 if self .is_downsample :
3637 x = self .downsample (x )
37- return F .relu (x .add (y ),True )
38+ return F .relu (x .add (y ), True )
39+
3840
def make_layers(c_in, c_out, repeat_times, is_downsample=False):
    """Build a stage of `repeat_times` BasicBlocks as an nn.Sequential.

    Only the first block may downsample and change the channel width
    (c_in -> c_out); every following block keeps c_out -> c_out.
    """
    stage = [
        BasicBlock(c_in, c_out, is_downsample=is_downsample)
        if idx == 0
        else BasicBlock(c_out, c_out)
        for idx in range(repeat_times)
    ]
    return nn.Sequential(*stage)
50+
4851class Net (nn .Module ):
49- def __init__ (self , num_classes = 751 , reid = False ):
50- super (Net ,self ).__init__ ()
52+ def __init__ (self , num_classes = 751 , reid = False ):
53+ super (Net , self ).__init__ ()
5154 # 3 128 64
5255 self .conv = nn .Sequential (
53- nn .Conv2d (3 ,64 ,3 , stride = 1 ,padding = 1 ),
56+ nn .Conv2d (3 , 64 , 3 , stride = 1 , padding = 1 ),
5457 nn .BatchNorm2d (64 ),
5558 nn .ReLU (inplace = True ),
5659 # nn.Conv2d(32,32,3,stride=1,padding=1),
5760 # nn.BatchNorm2d(32),
5861 # nn.ReLU(inplace=True),
59- nn .MaxPool2d (3 ,2 , padding = 1 ),
62+ nn .MaxPool2d (3 , 2 , padding = 1 ),
6063 )
6164 # 32 64 32
62- self .layer1 = make_layers (64 ,64 ,2 , False )
65+ self .layer1 = make_layers (64 , 64 , 2 , False )
6366 # 32 64 32
64- self .layer2 = make_layers (64 ,128 ,2 , True )
67+ self .layer2 = make_layers (64 , 128 , 2 , True )
6568 # 64 32 16
66- self .layer3 = make_layers (128 ,256 ,2 , True )
69+ self .layer3 = make_layers (128 , 256 , 2 , True )
6770 # 128 16 8
68- self .layer4 = make_layers (256 ,512 ,2 , True )
71+ self .layer4 = make_layers (256 , 512 , 2 , True )
6972 # 256 8 4
70- self .avgpool = nn .AvgPool2d (( 8 , 4 ), 1 )
73+ self .avgpool = nn .AdaptiveAvgPool2d ( 1 )
7174 # 256 1 1
7275 self .reid = reid
7376 self .classifier = nn .Sequential (
@@ -77,18 +80,18 @@ def __init__(self, num_classes=751 ,reid=False):
7780 nn .Dropout (),
7881 nn .Linear (256 , num_classes ),
7982 )
80-
83+
8184 def forward (self , x ):
8285 x = self .conv (x )
8386 x = self .layer1 (x )
8487 x = self .layer2 (x )
8588 x = self .layer3 (x )
8689 x = self .layer4 (x )
8790 x = self .avgpool (x )
88- x = x .view (x .size (0 ),- 1 )
91+ x = x .view (x .size (0 ), - 1 )
8992 # B x 128
9093 if self .reid :
91- x = x .div (x .norm (p = 2 ,dim = 1 ,keepdim = True ))
94+ x = x .div (x .norm (p = 2 , dim = 1 , keepdim = True ))
9295 return x
9396 # classifier
9497 x = self .classifier (x )
@@ -97,8 +100,6 @@ def forward(self, x):
97100
if __name__ == '__main__':
    # Smoke test: one forward pass on a batch of 4 RGB person crops (128x64).
    net = Net()
    dummy = torch.randn(4, 3, 128, 64)
    out = net(dummy)
0 commit comments