252 changes: 153 additions & 99 deletions python/paddle/fluid/tests/unittests/test_l1_loss.py
@@ -20,111 +20,165 @@
import unittest


(The old TestL1Loss class, which hand-built fluid programs with fluid.layers.data and compared static against dygraph results for the 'mean', 'sum', and 'none' reductions, is removed and replaced by the two test classes below.)

class TestFunctionalL1Loss(unittest.TestCase):
    def setUp(self):
        self.input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
        self.label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)

    def run_imperative(self):
        input = paddle.to_variable(self.input_np)
        label = paddle.to_variable(self.label_np)
        dy_result = paddle.nn.functional.l1_loss(input, label)
        expected = np.mean(np.abs(self.input_np - self.label_np))
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum')
        expected = np.sum(np.abs(self.input_np - self.label_np))
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

        dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none')
        expected = np.abs(self.input_np - self.label_np)
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [10, 10, 5])

def run_static(self, use_gpu=False):
input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.data(name='label', shape=[10, 10, 5], dtype='float32')
result0 = paddle.nn.functional.l1_loss(input, label)
result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum')
result2 = paddle.nn.functional.l1_loss(input, label, reduction='none')
y = paddle.nn.functional.l1_loss(input, label, name='aaa')

place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
static_result = exe.run(
feed={"input": self.input_np,
"label": self.label_np},
fetch_list=[result0, result1, result2])

expected = np.mean(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[0], expected))
expected = np.sum(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[1], expected))
expected = np.abs(self.input_np - self.label_np)
self.assertTrue(np.allclose(static_result[2], expected))

self.assertTrue('aaa' in y.name)

def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative()
paddle.enable_static()

with fluid.program_guard(fluid.Program()):
self.run_static()

def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return

paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()

with fluid.program_guard(fluid.Program()):
self.run_static(use_gpu=True)

    # test that an invalid reduction raises ValueError
def test_errors(self):
def test_value_error():
input = paddle.data(
name='input', shape=[10, 10, 5], dtype='float32')
            label = paddle.data(
                name='label', shape=[10, 10, 5], dtype='float32')
            loss = paddle.nn.functional.l1_loss(
                input, label, reduction='reduce_mean')

        self.assertRaises(ValueError, test_value_error)


class TestClassL1Loss(unittest.TestCase):
def setUp(self):
self.input_np = np.random.random(size=(10, 10, 5)).astype(np.float32)
self.label_np = np.random.random(size=(10, 10, 5)).astype(np.float32)

def run_imperative(self):
input = paddle.to_variable(self.input_np)
label = paddle.to_variable(self.label_np)
l1_loss = paddle.nn.loss.L1Loss()
dy_result = l1_loss(input, label)
expected = np.mean(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
dy_result = l1_loss(input, label)
expected = np.sum(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [1])

        l1_loss = paddle.nn.loss.L1Loss(reduction='none')
        dy_result = l1_loss(input, label)
        expected = np.abs(self.input_np - self.label_np)
        self.assertTrue(np.allclose(dy_result.numpy(), expected))
        self.assertEqual(dy_result.shape, [10, 10, 5])

def run_static(self, use_gpu=False):
input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32')
label = paddle.data(name='label', shape=[10, 10, 5], dtype='float32')
l1_loss = paddle.nn.loss.L1Loss()
result0 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(reduction='sum')
result1 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(reduction='none')
result2 = l1_loss(input, label)
l1_loss = paddle.nn.loss.L1Loss(name='aaa')
result3 = l1_loss(input, label)

place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
static_result = exe.run(
feed={"input": self.input_np,
"label": self.label_np},
fetch_list=[result0, result1, result2])

expected = np.mean(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[0], expected))
expected = np.sum(np.abs(self.input_np - self.label_np))
self.assertTrue(np.allclose(static_result[1], expected))
expected = np.abs(self.input_np - self.label_np)
self.assertTrue(np.allclose(static_result[2], expected))
self.assertTrue('aaa' in result3.name)

def test_cpu(self):
paddle.disable_static(place=paddle.fluid.CPUPlace())
self.run_imperative()
paddle.enable_static()

with fluid.program_guard(fluid.Program()):
self.run_static()

def test_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return

paddle.disable_static(place=paddle.fluid.CUDAPlace(0))
self.run_imperative()
paddle.enable_static()

with fluid.program_guard(fluid.Program()):
self.run_static(use_gpu=True)

    # test that an invalid reduction raises ValueError
def test_errors(self):
def test_value_error():
loss = paddle.nn.loss.L1Loss(reduction="reduce_mean")

self.assertRaises(ValueError, test_value_error)


if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
@@ -127,6 +127,7 @@
from .loss import huber_loss #DEFINE_ALIAS
from .loss import iou_similarity #DEFINE_ALIAS
from .loss import kldiv_loss #DEFINE_ALIAS
from .loss import l1_loss #DEFINE_ALIAS
from .loss import log_loss #DEFINE_ALIAS
from .loss import margin_rank_loss #DEFINE_ALIAS
from .loss import mse_loss #DEFINE_ALIAS
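With this alias line in place, the new function becomes reachable from the functional namespace. A minimal smoke check, assuming a paddle build that includes this PR:

    import paddle.nn.functional as F

    # The DEFINE_ALIAS import above re-exports loss.l1_loss here.
    print(F.l1_loss)             # <function l1_loss at 0x...>
    print('l1_loss' in dir(F))   # True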
94 changes: 94 additions & 0 deletions python/paddle/nn/functional/loss.py
@@ -13,6 +13,10 @@
# limitations under the License.

# TODO: define loss functions of neural network
import paddle
import paddle.fluid as fluid
from ...fluid.framework import core, in_dygraph_mode
from ...fluid.layers.nn import _elementwise_op_in_dygraph
from ...fluid.layers import bpr_loss #DEFINE_ALIAS
from ...fluid.layers import center_loss #DEFINE_ALIAS
from ...fluid.layers import cross_entropy #DEFINE_ALIAS
@@ -45,6 +49,7 @@
'huber_loss',
'iou_similarity',
'kldiv_loss',
'l1_loss',
'log_loss',
'margin_rank_loss',
'mse_loss',
@@ -60,3 +65,92 @@
'ssd_loss',
'teacher_student_sigmoid_loss'
]


def l1_loss(x, label, reduction='mean', name=None):
"""
This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows.

    If :attr:`reduction` is set to ``'none'``, the loss is:

[inline review] Contributor: "Should there be one more blank line here? Both the none and mean cases are followed by a blank line." Author: "done"
.. math::
Out = \lvert x - label\rvert

    If :attr:`reduction` is set to ``'mean'``, the loss is:

.. math::
Out = MEAN(\lvert x - label\rvert)

    If :attr:`reduction` is set to ``'sum'``, the loss is:

.. math::
Out = SUM(\lvert x - label\rvert)


Parameters:
        x (Tensor): The input tensor. Its shape is [N, *], where N is the batch size and `*` means any number of additional dimensions. Its data type should be float32, float64, int32 or int64.
        label (Tensor): The label tensor. Its shape is [N, *], the same as ``x``. Its data type should be float32, float64, int32 or int64.
        reduction (str, optional): Indicates the reduction to apply to the loss;
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            if :attr:`reduction` is ``'mean'``, the mean loss is returned;
            if :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
        Tensor, the L1 Loss of Tensor ``x`` and ``label``.
        If :attr:`reduction` is ``'none'``, the shape of the output loss is [N, *], the same as ``x``.
        If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of the output loss is [1], i.e. the output is a scalar.
Examples:
        .. code-block:: python

import paddle
import numpy as np

paddle.disable_static()
x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32")
label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32")
x = paddle.to_variable(x_data)
label = paddle.to_variable(label_data)

l1_loss = paddle.nn.functional.l1_loss(x, label)
print(l1_loss.numpy())
# [0.35]

l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='none')
print(l1_loss.numpy())
# [[0.20000005 0.19999999]
# [0.2 0.79999995]]

l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='sum')
print(l1_loss.numpy())
# [1.4]
"""
[inline review] Contributor: "The dygraph code should come first, using core.ops." Author: "done"

    if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in L1Loss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)

if in_dygraph_mode():
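        # Fast path for imperative (dygraph) mode: call the fused C++ ops
        # directly instead of constructing a static program.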
unreduced = _elementwise_op_in_dygraph(
x, label, axis=-1, act='abs', op_name='elementwise_sub')
if reduction == 'mean':
return core.ops.mean(unreduced)
elif reduction == 'sum':
return core.ops.reduce_sum(unreduced, 'dim', [0], 'keep_dim', False,
'reduce_all', True)
else:
return unreduced
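    # Static-graph path: validate dtypes up front, then build the ops through
    # the high-level paddle API so that ``name`` can be attached to the output.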

fluid.data_feeder.check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss')

if reduction == 'sum':
unreduced = paddle.elementwise_sub(x, label, act='abs')
return paddle.sum(unreduced, name=name)
elif reduction == 'mean':
unreduced = paddle.elementwise_sub(x, label, act='abs')
return paddle.mean(unreduced, name=name)
[inline review] Contributor: "Dygraph should use core.ops." Author: "done"

    else:
return paddle.elementwise_sub(x, label, act='abs', name=name)
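As a quick end-to-end sanity check of the three reduction modes, a minimal sketch (assuming a paddle build that contains this PR; it mirrors the unit tests above):

    import numpy as np
    import paddle

    paddle.disable_static()
    x_np = np.random.random((4, 3)).astype('float32')
    label_np = np.random.random((4, 3)).astype('float32')
    x = paddle.to_variable(x_np)
    label = paddle.to_variable(label_np)

    # 'none' keeps the elementwise |x - label|; 'mean' and 'sum' reduce to a scalar.
    assert np.allclose(paddle.nn.functional.l1_loss(x, label, reduction='none').numpy(),
                       np.abs(x_np - label_np))
    assert np.allclose(paddle.nn.functional.l1_loss(x, label).numpy(),
                       np.mean(np.abs(x_np - label_np)))
    assert np.allclose(paddle.nn.functional.l1_loss(x, label, reduction='sum').numpy(),
                       np.sum(np.abs(x_np - label_np)))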