Skip to content

Commit a225bfa

Browse files
authored
Merge pull request #3982 from PaddlePaddle/fix_batch_norm_parameter_share
fix batch_norm parameter share
2 parents 99e3d1e + 80a8e91 commit a225bfa

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

python/paddle/trainer/config_parser.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2368,6 +2368,7 @@ def __init__(self,
23682368
use_global_stats=True,
23692369
moving_average_fraction=0.9,
23702370
batch_norm_type=None,
2371+
mean_var_names=None,
23712372
**xargs):
23722373
if inputs is None:
23732374
inputs = []
@@ -2421,6 +2422,11 @@ def __init__(self,
24212422

24222423
psize = self.calc_parameter_size(image_conf)
24232424
dims = [1, psize]
2425+
if mean_var_names is not None:
2426+
assert len(mean_var_names) == 2
2427+
self.inputs[1].parameter_name = mean_var_names[0]
2428+
self.inputs[2].parameter_name = mean_var_names[1]
2429+
24242430
self.create_input_parameter(0, psize)
24252431
self.create_input_parameter(1, psize, dims)
24262432
self.create_input_parameter(2, psize, dims)

python/paddle/trainer_config_helpers/layers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2959,7 +2959,8 @@ def batch_norm_layer(input,
29592959
layer_attr=None,
29602960
batch_norm_type=None,
29612961
moving_average_fraction=0.9,
2962-
use_global_stats=None):
2962+
use_global_stats=None,
2963+
mean_var_names=None):
29632964
"""
29642965
Batch Normalization Layer. The notation of this layer is as follows.
29652966
@@ -3026,6 +3027,8 @@ def batch_norm_layer(input,
30263027
:math:`runningMean = newMean*(1-factor)
30273028
+ runningMean*factor`
30283029
:type moving_average_fraction: float.
3030+
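:param mean_var_names: Optional parameter names for the layer's mean and variance, given as [mean name, variance name]. If set, it must contain exactly two names.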
3031+
:type mean_var_names: list of string or None
30293032
:return: LayerOutput object.
30303033
:rtype: LayerOutput
30313034
"""
@@ -3047,6 +3050,7 @@ def batch_norm_layer(input,
30473050
bias=ParamAttr.to_bias(bias_attr),
30483051
moving_average_fraction=moving_average_fraction,
30493052
use_global_stats=use_global_stats,
3053+
mean_var_names=mean_var_names,
30503054
**ExtraLayerAttribute.to_kwargs(layer_attr))
30513055

30523056
return LayerOutput(

0 commit comments

Comments
 (0)