Merged

66 commits
80b884e
native commit for triple grad of sigmod
veyron95 Sep 22, 2021
d52b81c
Updated unittests files
veyron95 Sep 22, 2021
19d6b05
init functional jacobian api
Sep 22, 2021
f47b48f
merge upstream/develop
Sep 22, 2021
16c048a
Merge pull request #2 from veyron95/ops_derivative
JiabinYang Sep 22, 2021
a6a9053
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
JiabinYang Sep 22, 2021
4febae7
Updated trible_test func
veyron95 Sep 22, 2021
be9da74
Updated gradient_checker & test_script
veyron95 Sep 22, 2021
be2b30d
finish test with dtype float32
Sep 23, 2021
36b8c34
add float64 test case
Sep 23, 2021
35b1ce8
polish code
Sep 24, 2021
3a35a00
use atol=1e-5 with dtype float64
Sep 24, 2021
a3ea12e
fix for ci
Sep 24, 2021
8738cf8
set timeout for test_jacobian
Sep 24, 2021
d6e771e
fix dygraph grad to support high differential
JiabinYang Sep 24, 2021
0bd8287
polish API docstring
Sep 26, 2021
83c8395
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
veyron95 Sep 26, 2021
4109fc5
Updated gradient checker and some related files
veyron95 Sep 26, 2021
19e471c
Merge pull request #4 from veyron95/ops_derivative
JiabinYang Sep 26, 2021
1573b2c
Merge branch 'lml/jacobian' of https://github.com/levi131/Paddle into…
JiabinYang Sep 26, 2021
1408ef5
fix double grad strip error for high differential
JiabinYang Sep 26, 2021
ea78b6e
fix double grad strip error for high differential
JiabinYang Sep 26, 2021
2351a99
Add Sigmoid triple grad tests
veyron95 Sep 26, 2021
7a3fbd1
fix dygraph double grad dtype error when calling for high differentia…
JiabinYang Sep 26, 2021
42df611
Merge pull request #8 from veyron95/ops_derivative
JiabinYang Sep 26, 2021
a6dde75
Updated triple grad teses func
veyron95 Sep 27, 2021
848efcf
Use np.random to initialize ddx
veyron95 Sep 27, 2021
04eab89
Updated triple_grad_check func
veyron95 Sep 28, 2021
38ca20a
Merge pull request #9 from veyron95/ops_derivative
JiabinYang Sep 28, 2021
886d9fb
merge develop
JiabinYang Sep 28, 2021
e9f643d
add todo for gradient checker and refine some comments
JiabinYang Sep 28, 2021
2d6370b
remove additional code
JiabinYang Sep 28, 2021
a3b8e4e
add test for infer_var dtype warning
JiabinYang Sep 29, 2021
13af3ed
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
JiabinYang Sep 29, 2021
20ca8e7
add test for warnging in backward.py
JiabinYang Sep 29, 2021
a961e3c
format python code
JiabinYang Oct 11, 2021
ee5489d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Oct 11, 2021
a495960
support multi input in triple gradient checker
JiabinYang Oct 12, 2021
ebe8559
Add matmul triple grad kernel
veyron95 Oct 14, 2021
4f31159
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
veyron95 Oct 14, 2021
4d56a30
Updated comments of TODO
veyron95 Oct 14, 2021
15f2a32
Merge develop branch and all conflicts fixed
veyron95 Oct 14, 2021
07d1490
Supported some special tests
veyron95 Oct 14, 2021
d5fdd20
merge develop
JiabinYang Oct 15, 2021
0e44f39
merge jiabin/support_derivative branch
veyron95 Oct 15, 2021
b52794e
Change code-format to follow CI std
veyron95 Oct 18, 2021
4202d96
Updated gradient_checker.py
veyron95 Oct 19, 2021
91149a7
Fix conflicts
veyron95 Oct 19, 2021
e20ef17
Merge develop and fix conflicts
veyron95 Oct 19, 2021
d0741f4
Removed unnecessary printing log
veyron95 Oct 19, 2021
46dbd64
Change code style to follow CI std
veyron95 Oct 20, 2021
e32e10e
Merge remote-tracking branch '3rd_order/ops_derivative' into develop
Oct 20, 2021
46607df
Merge remote-tracking branch 'upstream/develop' into develop
Oct 20, 2021
9da53dd
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
Oct 22, 2021
528ef73
Merge remote-tracking branch 'upstream/develop' into develop
Nov 13, 2021
36a1dcb
support batch in jacobian and hessian
Nov 15, 2021
9a880bd
add batch jacobian and batch hessian
Nov 20, 2021
205c57f
Add batch_jacobian test, draft version
veyron95 Nov 22, 2021
d021233
[New features] Add elementwise_mul triple grad kernel (#37152)
veyron95 Nov 15, 2021
d3fc2af
Add numerical_batch_jacobian,numerical_batch_hessian and tests
veyron95 Nov 23, 2021
ced8536
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Nov 23, 2021
6a38ac8
Support batch_jacobian and batch_numerical
veyron95 Nov 25, 2021
4c6cb8e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
veyron95 Nov 25, 2021
9735d15
Use pre-commit to check code format
veyron95 Nov 25, 2021
cf9df58
Update doc, polish code, add unit test
veyron95 Nov 26, 2021
5f98f05
Reset the TIMEOUT properties of test_jacobian to pass CI
veyron95 Nov 26, 2021
3 changes: 2 additions & 1 deletion python/paddle/autograd/__init__.py
@@ -18,6 +18,7 @@
from .py_layer import PyLayer, PyLayerContext # noqa: F401
from ..framework import set_grad_enabled # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from .functional import vjp, jvp, jacobian, hessian, vhp # noqa: F401
from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401
from .functional import vjp, jvp, vhp # noqa: F401

__all__ = ['backward', 'PyLayer', 'PyLayerContext']
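For context, the import change above is what exposes the new batch APIs under the public ``paddle.autograd`` namespace. Below is a minimal usage sketch (not part of the diff; it assumes a Paddle build that already contains this PR and runs in dygraph mode, Paddle's default):

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')  # first dimension is the batch size
x.stop_gradient = False

# Both functions are re-exported from paddle.autograd.functional by this change.
jac = paddle.autograd.batch_jacobian(lambda v: v * v, x)
hes = paddle.autograd.batch_hessian(
    lambda v: paddle.sum(v * v, axis=1, keepdim=True), x)
print(jac.shape, hes.shape)  # [2, 8] [2, 8]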
291 changes: 291 additions & 0 deletions python/paddle/autograd/functional.py
@@ -385,6 +385,297 @@ def func(x, y):
return jacobian


@framework.dygraph_only
def batch_jacobian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in the imperative mode.**

This function computes the batch Jacobian matrix of `func` with respect to `inputs`.
Note that the first dimension of ``inputs`` is the batch size.

Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs (the first dimension is the batch size) and
returns a Tensor or a Tensor tuple.
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``. Note that
the first dimension of ``inputs`` is the batch size.
create_graph (bool, optional): whether to create the gradient graphs
of the computing process. When it is True, higher order derivatives
can be computed; when it is False, the gradient graphs of
the computing process are discarded. Defaults to ``False``.
allow_unused (bool, optional): whether to raise an error or return None if
some Tensors of ``inputs`` are unreachable in the graph. An error is
raised if allow_unused=False, and None is returned as
their gradients if allow_unused=True. Defaults to ``False``.
Returns:
Jacobian (Tensor or nested tuple of Tensors): if function ``func``
takes a Tensor as inputs and returns a Tensor as outputs, Jacobian
will be a single Tensor containing the Jacobian matrix for the
linearized inputs and outputs. If one of the inputs or outputs is
a Tensor and the other is a Tensor list/tuple, then the Jacobian will
be a tuple of Tensors. If both inputs and outputs are Tensor
lists/tuples, then the Jacobian will be a tuple of tuples of Tensors.
Note that the first dimension of ``inputs`` is the batch size.

For example,
if the inputs shape and outputs shape of function ``func`` are [batch_size, num]
and [batch_size, num] respectively, then the Jacobian will be a Tensor with
a shape of [num, batch_size * num], where ``Jacobian[i][j]`` will contain
the Jacobian matrix of the ``i``th column of the output with respect to the
``j``th input, and will have the same dtype and device as the corresponding input.
Other situations can be deduced by analogy.

Examples 1:
.. code-block:: python

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')

def func(x):
return paddle.matmul(paddle.matmul(x, weight), y)

x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, x)
print(batch_jacobian)
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 4., 4., 4., 4., 4., 4., 4.],
# [4., 4., 4., 4., 4., 4., 4., 4.]])

Examples 2:
.. code-block:: python

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')

def func(x):
return paddle.matmul(paddle.matmul(x, weight), y), x * x

x.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, x)
print(batch_jacobian)
# (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 4., 4., 4., 4., 4., 4., 4.],
# [4., 4., 4., 4., 4., 4., 4., 4.]]), Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]]))

Examples 3:
.. code-block:: python

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')

def func(x, y):
return x * y

x.stop_gradient = False
y.stop_gradient = False
batch_jacobian = paddle.autograd.batch_jacobian(func, [x, y])
print(batch_jacobian)
# (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0., 1., 0., 1.]]), Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0., 1., 0., 1., 0., 1., 0.],
# [0., 1., 0., 1., 0., 1., 0., 1.]]))

'''
inputs = _tensors(inputs, "inputs")
outputs = _tensors(func(*inputs), "outputs")
batch_size = inputs[0].shape[0]
for input in inputs:
assert input.shape[
0] == batch_size, "The first dimension of each input should equal the batch size!"
for output in outputs:
assert output.shape[
0] == batch_size, "The first dimension of each output should equal the batch size!"
fin_size = len(inputs)
fout_size = len(outputs)
flat_outputs = tuple(
reshape(
output, shape=[batch_size, -1]) for output in outputs)
jacobian = tuple()
for i, flat_output in enumerate(flat_outputs):
jac_i = list([] for _ in range(fin_size))
for k in range(flat_output.shape[1]):
row_k = grad(
flat_output[:, k],
inputs,
create_graph=create_graph,
retain_graph=True,
allow_unused=allow_unused)
for j in range(fin_size):
jac_i[j].append(
reshape(
row_k[j], shape=[-1])
if isinstance(row_k[j], paddle.Tensor) else None)
jacobian += (tuple(
_stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), )
if fin_size == 1 and fout_size == 1:
return jacobian[0][0]
elif fin_size == 1 and fout_size != 1:
return tuple(jacobian[i][0] for i in range(fout_size))
elif fin_size != 1 and fout_size == 1:
return jacobian[0]
else:
return jacobian


@framework.dygraph_only
def batch_hessian(func, inputs, create_graph=False, allow_unused=False):
'''
.. note::
**This API is ONLY available in the imperative mode.**

This function computes the batch Hessian matrix of `func` with respect to `inputs`.
Note that the first dimension of ``inputs`` is the batch size.

Parameters:
func (function): a Python function that takes a Tensor or a Tensor
list/tuple as inputs (the first dimension is the batch size) and
returns a Tensor with shape [batch_size, 1].
inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
Tensor list/tuple of the function ``func``.
Note that the first dimension of ``inputs`` is the batch size.
create_graph (bool, optional): whether to create the gradient graphs
of the computing process. When it is True, higher order derivatives
can be computed; when it is False, the gradient graphs of
the computing process are discarded. Defaults to ``False``.
allow_unused (bool, optional): whether to raise an error or return None if
some Tensors of ``inputs`` are unreachable in the graph. An error is
raised if allow_unused=False, and None is returned as
their gradients if allow_unused=True. Defaults to ``False``.
Returns:
Hessian (Tensor or a tuple of tuple of Tensors): if function ``func``
takes a Tensor as ``inputs``, Hessian will be a single Tensor containing
the Hessian matrix for the linearized ``inputs`` Tensor. If function
``func`` takes a Tensor list/tuple as ``inputs``, then the Hessian will
be a tuple of tuple of Tensors. Note that the first dimension of ``inputs``
is the batch size, and that the computation first obtains the first order
derivatives and then differentiates them with respect to the batched input.

For example,
if the inputs shape and outputs shape of function ``func`` are [batch_size, num]
and [batch_size, 1] respectively, then the batched Hessian will be a Tensor with
a shape of [num, batch_size * num].

Why is the final shape [num, batch_size * num] in this case?
Because batch_hessian creates an inner function (a wrapper around paddle.grad())
that computes the sum of the gradients of ``outputs`` with respect to each of the ``inputs``.
This inner function returns the first order derivatives, whose shape is [batch_size, num];
batch_jacobian is then called to compute the Jacobian between these first order derivatives
and the original inputs. The final result ``Hessian[i][j]`` will contain the Jacobian
matrix of the ``i``th column of the output (where the output here means the first order
derivatives) and the ``j``th input, and will have the same dtype and device as the
corresponding input. Other situations can be deduced by analogy.


Examples 1:
.. code-block:: python

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')

def func(x):
return paddle.matmul(x * x, weight)[:, 0:1]


x.stop_gradient = False
batch_hessian = paddle.autograd.batch_hessian(func, x)
print(batch_hessian)
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]])

Examples 2:
.. code-block:: python

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')

def func(x, y):
return paddle.matmul(x * x * y * y, weight)[:, 0:1]

x.stop_gradient = False
y.stop_gradient = False
batch_hessian = paddle.autograd.batch_hessian(func, [x, y])
print(batch_hessian)
# ((Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]]),
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 0., 4., 0., 4., 0., 4., 0.],
# [0., 4., 0., 4., 0., 4., 0., 4.]])),
# (Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[4., 0., 4., 0., 4., 0., 4., 0.],
# [0., 4., 0., 4., 0., 4., 0., 4.]]),
# Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]])))


Examples 3:
.. code-block:: python

import paddle

x = paddle.ones(shape=(4, 2), dtype='float64')
weight = paddle.ones(shape=(2, 4), dtype='float64')
y = paddle.ones(shape=(4, 2), dtype='float64')

def func(x, y):
return paddle.matmul(x * x, weight)[:, 0:1]

x.stop_gradient = False
y.stop_gradient = False
batch_hessian = paddle.autograd.batch_hessian(func, [x, y], allow_unused=True)
print(batch_hessian)
# ((Tensor(shape=[2, 8], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
# [[2., 0., 2., 0., 2., 0., 2., 0.],
# [0., 2., 0., 2., 0., 2., 0., 2.]]), None), (None, None))

'''
inputs = _tensors(inputs, "inputs")
outputs = func(*inputs)
batch_size = inputs[0].shape[0]
for input in inputs:
assert input.shape[
0] == batch_size, "The first dimension of each input should equal the batch size!"
assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
batch_size, 1
], "The function to compute batched Hessian matrix should return a Tensor of shape [batch_size, 1]"

def jac_func(*ins):
grad_inputs = grad(
outputs,
ins,
create_graph=True,
retain_graph=True,
allow_unused=allow_unused)
return tuple(
_replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
for i in range(len(inputs)))

return batch_jacobian(
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused)


@framework.dygraph_only
def hessian(func, inputs, create_graph=False, allow_unused=False):
'''
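As a sanity check on the [num_out, batch_size * num_in] layout described in the batch_jacobian docstring above, the sketch below reproduces the first docstring example by hand. It is illustrative only and not part of the diff; the NumPy reference relies on the fact that, for this matmul example, each sample's output depends only on that sample's input.

import numpy as np
import paddle

batch_size, num_in = 4, 2
weight_np = np.ones((num_in, 4))
y_np = np.ones((4, num_in))

# Per-sample Jacobian of out = x @ weight @ y with respect to x:
# d out[b, i] / d x[b, j] = (weight @ y)[j, i], the same for every sample b.
per_sample = (weight_np @ y_np).T              # shape [num_out, num_in], all entries 4
num_out = per_sample.shape[0]

# Row i of the batch Jacobian holds d(sum over b of out[b, i]) / d x, flattened over
# (batch, num_in), so the per-sample block repeats once per sample along the columns.
reference = np.zeros((num_out, batch_size * num_in))
for b in range(batch_size):
    reference[:, b * num_in:(b + 1) * num_in] = per_sample

x = paddle.ones(shape=(batch_size, num_in), dtype='float64')
x.stop_gradient = False
weight = paddle.to_tensor(weight_np)
y = paddle.to_tensor(y_np)
jac = paddle.autograd.batch_jacobian(
    lambda v: paddle.matmul(paddle.matmul(v, weight), y), x)
np.testing.assert_allclose(jac.numpy(), reference)  # both are all-4 matrices of shape [2, 8]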
CMakeLists.txt for the autograd unit tests
@@ -6,6 +6,6 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
endforeach(TEST_OP)

set_tests_properties(test_jacobian PROPERTIES TIMEOUT 20)
set_tests_properties(test_jacobian PROPERTIES TIMEOUT 50)
set_tests_properties(test_hessian PROPERTIES TIMEOUT 50)
set_tests_properties(test_vhp PROPERTIES TIMEOUT 50)