Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add support for quantized LeakyReLU
Summary: Also adds support for backend_config Reviewed By: mcr229 Differential Revision: D47043207 fbshipit-source-id: 509bd4c02eb7ff5d3d47762522debd827bee7240
  • Loading branch information
digantdesai authored and facebook-github-bot committed Jun 28, 2023
commit e276b24cd7fc7ce13f5086a5eb3a507b973919e4
3 changes: 3 additions & 0 deletions backends/xnnpack/partition/xnnpack_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ def __init__(self):
torch.nn.ReLU,
torch.nn.functional.relu,
torch.nn.functional.relu_,
torch.nn.functional.leaky_relu,
torch.nn.functional.leaky_relu_,
torch.nn.LeakyReLU,
]

# Modules which support dynamic quantization
Expand Down
41 changes: 41 additions & 0 deletions backends/xnnpack/test/test_xnnpack_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,47 @@ def test_xnnpack_qhardtanh(self):
example_inputs = (torch.randn(1, 1, 1),)
self.quantize_and_test_model(torch.nn.Hardtanh(), example_inputs)

def test_xnnpack_leaky_relu(self):
example_inputs = (torch.randn(1, 3, 3),)

class LeakyReLUModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.leaky_relu_out_of_place = torch.nn.LeakyReLU(negative_slope=0.2)

def forward(self, x):
return self.leaky_relu_out_of_place(x)

self.quantize_and_test_model(LeakyReLUModule(), example_inputs)

def test_xnnpack_leaky_relu2(self):
example_inputs = (torch.randn(1, 3, 3),)

class LeakyReLUModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.leaky_relu_in_place = torch.nn.LeakyReLU(
negative_slope=0.08, inplace=True
)

def forward(self, x):
return self.leaky_relu_in_place(x)

self.quantize_and_test_model(LeakyReLUModule(), example_inputs)

def test_xnnpack_leaky_relu3(self):
example_inputs = (torch.randn(1, 3, 3),)

class LeakyReLUModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.leaky_relu_functional_default = torch.nn.functional.leaky_relu

def forward(self, x):
return self.leaky_relu_functional_default(x)

self.quantize_and_test_model(LeakyReLUModule(), example_inputs)

def test_xnnpack_qlinear(self):
in_size = 1
input_size = 3
Expand Down