
Commit 8623e48

Add python API for backward regularization ops (#5135)
* Add regularizer code
* Fix code
1 parent be00b0c commit 8623e48

File tree

4 files changed, +147 -0 lines changed

python/paddle/v2/framework/framework.py

Lines changed: 2 additions & 0 deletions
@@ -505,6 +505,8 @@ def __init__(self, block, shape, dtype, **kwargs):
         self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
 
+        self.regularizer = kwargs.get('regularizer', None)
+
 
 # program is a global instance.
 g_program = Program()
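With this change a Parameter accepts an optional regularizer keyword. A minimal sketch of how the kwarg is intended to be used, mirroring the new test added later in this commit (the parameter name and coefficient here are illustrative):

import paddle.v2.framework.framework as framework
import paddle.v2.framework.regularizer as regularizer

program = framework.Program()
block = program.global_block()

# The new 'regularizer' kwarg is stored on the Parameter and later
# consulted by append_regularization_ops during the backward pass.
w = block.create_parameter(
    dtype="float32",
    shape=[5, 10],
    lod_level=0,
    name="w",
    regularizer=regularizer.L2DecayRegularizer(regularization_coeff=0.01))

assert isinstance(w.regularizer, regularizer.L2DecayRegularizer)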

python/paddle/v2/framework/optimizer.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 import paddle.v2.framework.framework as framework
 from paddle.v2.framework.backward import append_backward_ops
+from paddle.v2.framework.regularizer import append_regularization_ops
 
 __all__ = [
     'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
@@ -161,6 +162,8 @@ def minimize(self, loss, parameter_list=None, no_grad_set=None):
         """
         params_grads = append_backward_ops(loss, parameter_list, no_grad_set or
                                            set())
+        # Add regularization if any
+        params_grads = append_regularization_ops(params_grads)
         optimize_ops = self.create_optimization_pass(params_grads, loss)
         return optimize_ops
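Conceptually, the two ops that append_regularization_ops inserts (a scale followed by an elementwise_add) hand the optimization pass grad + reg_coeff * param instead of the plain gradient, which for SGD is exactly classic weight decay. A small NumPy sketch of that arithmetic, purely for illustration (the real ops run inside the Paddle block, not in NumPy):

import numpy as np

lr, coeff = 0.1, 0.5
param = np.random.rand(5, 10).astype("float32")
grad = np.random.rand(5, 10).astype("float32")

decay = coeff * param       # what the appended 'scale' op computes
reg_grad = grad + decay     # what the appended 'elementwise_add' op computes

# The optimization pass then sees the regularized gradient, so a plain
# SGD step becomes SGD with weight decay:
sgd_on_reg_grad = param - lr * reg_grad
sgd_with_weight_decay = param - lr * grad - lr * coeff * param
assert np.allclose(sgd_on_reg_grad, sgd_with_weight_decay)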

python/paddle/v2/framework/regularizer.py (new file)

Lines changed: 99 additions & 0 deletions

import paddle.v2.framework.framework as framework

__all__ = ['append_regularization_ops', 'L2DecayRegularizer']


def append_regularization_ops(parameters_and_grads):
    """Create and add backward regularization Operators

    Creates and adds backward regularization operators in the BlockDesc.
    This will add gradients of the regularizer function to the gradients
    of the parameters and return these modified gradients. This is the
    same as implementing weight decay in optimizers for regularization.

    Args:
        parameters_and_grads: A list of (parameters, gradients) pairs
                              that need to be regularized.

    Returns:
        list of (parameters, gradients) pairs with the regularized gradient

    Raises:
        Exception: Unknown regularization type
    """
    params_and_grads = []
    for param, grad in parameters_and_grads:
        # If no gradient or no regularization is specified,
        # we don't need to do anything
        if grad is None or param.regularizer is None:
            params_and_grads.append((param, grad))
            continue

        # Add variable for regularization term in grad block
        regularization_term = param.regularizer(param, grad.block)
        assert grad.shape == regularization_term.shape

        grad.block.append_op(
            type='elementwise_add',
            inputs={"X": grad,
                    "Y": regularization_term},
            outputs={"Out": grad})
        params_and_grads.append((param, grad))

    return params_and_grads


class WeightDecayRegularizer(object):
    """Base class for weight decay regularizers

    Defines the common interface of weight-decay regularizers.
    Weight-decay regularizers are added only during the backward
    pass for faster regularization. They add operations to the network
    that correspond to the gradient of the regularization function.
    Users should not use this class directly, but should use one
    of its implementations.
    """

    def __init__(self):
        pass

    def __call__(self, param, block):
        """Add corresponding weight decay operations to the network
        """
        raise NotImplementedError()


class L2DecayRegularizer(WeightDecayRegularizer):
    """Implements the L2 Weight Decay Regularization
    """

    def __init__(self, regularization_coeff=0.0):
        assert regularization_coeff is not None
        super(L2DecayRegularizer, self).__init__()
        self._regularization_coeff = regularization_coeff

    def __call__(self, param, block):
        """Add L2 weight decay ops to the network

        Adds ops that compute
            L2WeightDecay = reg_coeff * parameter

        Args:
            param: parameter variable for which regularization is applied
            block: block in which the variable is to be created

        Returns:
            new variable holding the weight decay term
        """
        assert isinstance(param, framework.Parameter)
        assert isinstance(block, framework.Block)
        decay = block.create_var(
            dtype="float32", shape=param.shape, lod_level=param.lod_level)
        # Append op to calculate the decay term
        block.append_op(
            type='scale',
            inputs={"X": param},
            outputs={"Out": decay},
            attrs={"scale": self._regularization_coeff})

        return decay
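A minimal sketch of invoking a regularizer directly, which is what append_regularization_ops does internally; the parameter name and coefficient are illustrative, and the graph-building calls mirror the new test below:

import paddle.v2.framework.framework as framework
from paddle.v2.framework.regularizer import L2DecayRegularizer

program = framework.Program()
block = program.global_block()
w = block.create_parameter(
    dtype="float32", shape=[4, 4], lod_level=0, name="w",
    regularizer=L2DecayRegularizer(regularization_coeff=0.1))

# Invoking the regularizer appends a 'scale' op to the block and
# returns a new variable holding reg_coeff * w.
decay = w.regularizer(w, block)
assert decay.shape == w.shape
assert block.ops[-1].type == 'scale'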
python/paddle/v2/framework/tests/test_regularizer.py (new file)

Lines changed: 43 additions & 0 deletions

import unittest

import paddle.v2.framework.framework as framework
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.framework.regularizer as regularizer
from paddle.v2.framework.backward import append_backward_ops


class TestL2DecayRegularizer(unittest.TestCase):
    def test_l2decay_regularizer(self):
        program = framework.Program()
        block = program.global_block()
        mul_x = block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="mul.x",
            regularizer=regularizer.L2DecayRegularizer(0.5))
        self.assertTrue(mul_x.regularizer is not None)
        self.assertTrue(
            isinstance(mul_x.regularizer, regularizer.L2DecayRegularizer))
        mul_y = block.create_var(
            dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
        mul_out = block.create_var(
            dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
        block.append_op(
            type="mul",
            inputs={"X": mul_x,
                    "Y": mul_y},
            outputs={"Out": mul_out},
            attrs={"x_num_col_dims": 1})
        params_grads = append_backward_ops(mul_out)
        self.assertEqual(len(params_grads), 1)
        count_ops = len(block.ops)
        params_grads = optimizer.append_regularization_ops(params_grads)
        self.assertEqual(len(params_grads), 1)
        self.assertEqual(len(block.ops), count_ops + 2)
        self.assertEqual(block.ops[-1].type, 'elementwise_add')
        self.assertEqual(block.ops[-2].type, 'scale')


if __name__ == '__main__':
    unittest.main()
