
Commit 301bd87

Prepack conv weights (#31)
Currently we prepack conv weights in module.to.

Pros:
- Thread-safe and more explicit, compared to implicitly prepacking at runtime.

Cons:
- No input info, so the queried format might not be optimal. This does not show up for conv, but it could be a problem for linear weights.
- Unable to re-pack conv weights once they have been reordered back to plain format (maybe not even a con).
1 parent 13c3dcb commit 301bd87
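To illustrate the intended behavior, here is a minimal usage sketch (not part of this commit). It assumes the extension is imported as intel_pytorch_extension, which installs the Module.to hook from ops/module.py; the exact import name is an assumption.

    import torch
    import intel_pytorch_extension  # assumed import name; importing installs the Module.to hook

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
        torch.nn.ReLU(),
    )

    # Moving the module to the dpcpp device now also walks its children and calls
    # core.prepack_conv_weight on every Conv2d / Conv3d weight it finds.
    model = model.to('dpcpp')

    x = torch.randn(1, 3, 224, 224).to('dpcpp')
    y = model(x)  # conv consumes a weight already in the DNNL-preferred blocked layout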

File tree

13 files changed: +160 -40 lines

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,4 +5,4 @@
 from .reshape import *
 from .mlp import *
 from .linear_fuse_relu import *
-
+from .module import *
intel_pytorch_extension_py/ops/module.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import torch
+import _torch_ipex as core
+
+
+orig_module_to = torch.nn.Module.to
+
+def module_to(self, *args, **kwargs):
+    def prepack(m):
+        if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv3d):
+            core.prepack_conv_weight(m.weight, m.padding, m.stride, m.dilation, m.groups)
+
+    def prepack_reccur(m):
+        prepack(m)
+        for _, sub_m in m.named_children():
+            prepack_reccur(sub_m)
+
+    m = orig_module_to(self, *args, **kwargs)
+
+    device = torch._C._nn._parse_to(*args, **kwargs)[0]
+    if device and device.type == 'dpcpp':
+        prepack_reccur(m)
+
+    return m
+
+
+torch.nn.Module.to = module_to
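As a rough standalone sketch (not part of the commit), the hook above combines two pieces: device parsing via torch._C._nn._parse_to and a per-module prepack call. The import name intel_pytorch_extension is an assumption; _torch_ipex is the pybind module that ops/module.py imports as core.

    import torch
    import _torch_ipex as core
    import intel_pytorch_extension  # assumed import name; registers the 'dpcpp' device

    # torch._C._nn._parse_to is what module_to uses to recover the target device
    # from whatever positional/keyword arguments were passed to .to().
    device = torch._C._nn._parse_to('dpcpp')[0]
    assert device is not None and device.type == 'dpcpp'

    conv = torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1).to(device)

    # The same call prepack() issues for each Conv2d/Conv3d child it visits; the
    # weight must already live on dpcpp (enforced by a TORCH_CHECK in Prepack.cpp).
    core.prepack_conv_weight(conv.weight, conv.padding, conv.stride, conv.dilation, conv.groups)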

torch_ipex/csrc/aten_ipex_bridge.cpp

Lines changed: 35 additions & 1 deletion
@@ -11,7 +11,6 @@
 
 #include "ipex_tensor_impl.h"
 #include "ipex_sparse_tensor_impl.h"
-#include "cpu/dbl/Common.h"
 #include "cpu/ShadeDataContext.h"
 #include "cpu/bf16/Converter.h"
 #include "utils.h"
@@ -105,6 +104,41 @@ void reorderDilTensorToPublic(const at::Tensor& ipexTensor) {
 }
 }
 
+void reorderDilTensorGeneric(const at::Tensor& ipexTensor, const dil::tensor::desc& dstDesc) {
+  // ipexTensor is not required to be a DIL tensor
+  dil::tensor src = cpu::dbl::comm::try_gen_dil_tensor(ipexTensor);
+  dil::tensor dst {dstDesc};
+  dst.feed_from(src);
+
+  cpu::ShadeDataContext *new_shade_data_context = cpu::ShadeDataContext::allocShadeDataContext();
+  new_shade_data_context->data_type = cpu::SHADE_DATA_TYPE::DIL;
+  new_shade_data_context->dil_tensor = dst;
+
+  if (dstDesc.is_plain()) {
+    // Share raw data with DNNL because it is in plain format now
+    new_shade_data_context->cpu_raw_data = dst.get_data_handle();
+    // Cannot free the CPU data because the data is owned by DNNL
+    new_shade_data_context->cpu_del_fun = &(c10::detail::deleteNothing);
+  } else {
+    // If the tensor is in blocked format, cpu raw data means nothing here.
+    new_shade_data_context->cpu_raw_data = nullptr;
+    new_shade_data_context->cpu_del_fun = nullptr;
+  }
+
+  // Create a new DataPtr instance because the DataPtr class does not support setting
+  // its data or context directly
+  c10::DataPtr shade_data_ptr(
+    new_shade_data_context->cpu_raw_data,
+    new_shade_data_context,
+    &(cpu::ShadeDataContext::freeShadeDataContext),
+    ipexTensor.device().type());
+
+  ipexTensor.unsafeGetTensorImpl()->storage().set_data_ptr(std::move(shade_data_ptr));
+
+  if (dstDesc.is_plain()) {
+    cpu::dbl::comm::sync_shape_from_dil_to_aten(ipexTensor, dst);
+  }
+}
 
 void attachShadeDataContext(const at::Tensor& tensor) {
   auto tensor_storage_impl = tensor.storage().unsafeGetStorageImpl();

torch_ipex/csrc/aten_ipex_bridge.h

Lines changed: 8 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include <ATen/Device.h>
 #include <ATen/Functions.h>
 #include <ATen/Tensor.h>
+#include "cpu/dbl/Common.h"
 
 #include <vector>
 
@@ -22,6 +23,13 @@ void attachShadeDataContext(const at::Tensor& tensor);
  */
 void reorderDilTensorToPublic(const at::Tensor& ipexTensor);
 
+/**
+ * Reorder the tensor to a DNNL tensor with the specified descriptor, whether or not the input is already a DNNL tensor
+ *
+ * @param[in] ipexTensor The input tensor to be reordered to the specified DNNL descriptor
+ */
+void reorderDilTensorGeneric(const at::Tensor& ipexTensor, const dil::tensor::desc& dstDesc);
+
 /**
  * Reorder the input tensor to the specified scalar type. It is an optimized version for
  * DNNL OP. It means that if DNNL supports current OP, you should call this API. Otherwise, you

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 1 addition & 1 deletion
@@ -859,7 +859,7 @@ at::Tensor AtenIpexCPUDev::dil_adaptive_avg_pool2d(
   DEBUG("AtenIpexCPUDev::dil_adaptive_avg_pool2d\n");
   CHECK_DNNL_OP_PRE_COND(input);
   auto output_size_vec =
-      dbl::pool::expand_param_if_needed(output_size, "output_size", input.dim() - 2);
+      dbl::comm::expand_param_if_needed(output_size, "output_size", input.dim() - 2);
   std::vector<int64_t> kernel_size(input.dim() - 2);
   for (int64_t i = 2; i < input.dim(); ++i) {
     auto s1 = input.size(i);

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 11 additions & 9 deletions
@@ -20,6 +20,8 @@
 namespace torch_ipex {
 namespace cpu {
 
+using namespace dbl::comm;
+
 at::Tensor AtenIpexJITDev::dil_convolution_relu(
     const at::Tensor & input,
     const at::Tensor & weight,
@@ -35,11 +37,11 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
   auto input_contiguous = input.contiguous();
   auto weight_contiguous = weight.contiguous();
 
-  dil_input = dbl::comm::try_gen_dil_tensor(input_contiguous);
-  dil_weight = dbl::comm::try_gen_dil_tensor(weight_contiguous);
+  dil_input = try_gen_dil_tensor(input_contiguous);
+  dil_weight = try_gen_dil_tensor(weight_contiguous);
   if (bias.defined()) {
     auto bias_contiguous = bias.contiguous();
-    dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous);
+    dil_bias = try_gen_dil_tensor(bias_contiguous);
   }
 
   dil::tensor dil_output = dbl::conv::conv2d_impl(
@@ -52,7 +54,7 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
     groups,
     dil::attr_t::fuse_relu());
 
-  return dbl::comm::gen_aten_tensor_by(std::move(dil_output));
+  return gen_aten_tensor_by(std::move(dil_output));
 }
 
 static at::Tensor& dil_convolution_inplace_fusion(
@@ -74,12 +76,12 @@ static at::Tensor& dil_convolution_inplace_fusion(
   auto weight_contiguous = weight.contiguous();
   auto output_contiguous = accumu.contiguous();
 
-  dil_input = dbl::comm::try_gen_dil_tensor(input_contiguous);
-  dil_weight = dbl::comm::try_gen_dil_tensor(weight_contiguous);
-  dil_output = dbl::comm::try_gen_dil_tensor(output_contiguous);
+  dil_input = try_gen_dil_tensor(input_contiguous);
+  dil_weight = try_gen_dil_tensor(weight_contiguous);
+  dil_output = try_gen_dil_tensor(output_contiguous);
   if (bias.defined()) {
     auto bias_contiguous = bias.contiguous();
-    dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous);
+    dil_bias = try_gen_dil_tensor(bias_contiguous);
   }
 
   dbl::conv::conv2d_inplace_impl(
@@ -93,7 +95,7 @@ static at::Tensor& dil_convolution_inplace_fusion(
       groups,
       attr);
 
-  dbl::comm::sync_shape_from_dil_to_aten(accumu, dil_output);
+  sync_shape_from_dil_to_aten(accumu, dil_output);
   return accumu;
 }

torch_ipex/csrc/cpu/Prepack.cpp

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#include "Prepack.h"
+#include "dbl/Common.h"
+#include "torch_ipex/csrc/aten_ipex_bridge.h"
+#include "torch_ipex/csrc/utils.h"
+
+namespace torch_ipex {
+
+using namespace cpu::dbl::comm;
+
+void AtenIpexPrepack::prepack_conv_weight(
+    at::Tensor &weight,
+    at::IntArrayRef padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups) {
+  TORCH_CHECK(weight.device().type() == at::DeviceType::DPCPP,
+      "Cannot prepack a non-dpcpp tensor. Call t.to('dpcpp') first.");
+
+  auto kdims = weight.dim() - 2;
+  auto stride_vec = expand_param_if_needed(stride, "stride", kdims);
+  auto padding_vec = expand_param_if_needed(padding, "padding", kdims);
+  auto dilation_vec = expand_param_if_needed(dilation, "dilation", kdims);
+
+  auto packed_desc =
+      dil::convolution_forward::expected_weights_desc(
+          weight.sizes().vec(),
+          torch_ipex::get_dil_data_type(weight.scalar_type()),
+          stride_vec,
+          padding_vec,
+          padding_vec,
+          dilation_vec,
+          groups);
+
+  bridge::reorderDilTensorGeneric(weight, packed_desc);
+}
+
+} // namespace torch_ipex
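The TORCH_CHECK above makes prepacking valid only for weights that already live on the dpcpp device. Below is a small sketch of the expected failure mode when that precondition is violated, assuming the binding is exposed to Python as core.prepack_conv_weight, as used in ops/module.py.

    import torch
    import _torch_ipex as core

    conv = torch.nn.Conv2d(8, 16, kernel_size=3)  # weight is still a plain CPU tensor

    try:
        core.prepack_conv_weight(conv.weight, conv.padding, conv.stride, conv.dilation, conv.groups)
    except RuntimeError as err:
        # Expected message: "Cannot prepack a non-dpcpp tensor. Call t.to('dpcpp') first."
        print(err)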

torch_ipex/csrc/cpu/Prepack.h

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+
+namespace torch_ipex {
+
+class AtenIpexPrepack {
+ public:
+  static void prepack_conv_weight(at::Tensor &weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups);
+};
+
+} // namespace torch_ipex

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 18 additions & 1 deletion
@@ -73,7 +73,7 @@ at::Tensor gen_aten_tensor_by(dil::tensor&& dil_tensor) {
     nullptr,
     /*resizeable=*/false);
   auto _tensor = at::detail::make_tensor<torch_ipex::IPEXTensorImpl>(storage_impl, at::DispatchKey::DPCPPTensorId);
-  dbl::comm::sync_shape_from_dil_to_aten(_tensor, shade_data_context->dil_tensor.value());
+  sync_shape_from_dil_to_aten(_tensor, shade_data_context->dil_tensor.value());
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(_tensor.layout() == c10::kStrided);
   return _tensor;
 }
@@ -101,6 +101,23 @@ void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tenso
 }
 }
 
+std::vector<int64_t> expand_param_if_needed(
+    at::IntArrayRef list_param,
+    const char* param_name,
+    int64_t expected_dim) {
+  if (list_param.size() == 1) {
+    return std::vector<int64_t>(expected_dim, list_param[0]);
+  } else if ((int64_t)list_param.size() != expected_dim) {
+    std::ostringstream ss;
+    ss << "expected " << param_name << " to be a single integer value or a "
+       << "list of " << expected_dim << " values to match the convolution "
+       << "dimensions, but got " << param_name << "=" << list_param;
+    AT_ERROR(ss.str());
+  } else {
+    return list_param.vec();
+  }
+}
+
 } // namespace comm
 } // namespace dbl
 } // namespace cpu
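For illustration only, here is a Python rendering of the same expansion rule: a single value is broadcast across all spatial dims, otherwise the length must match the convolution dimensionality. The C++ helper above is what the extension actually uses.

    def expand_param_if_needed(list_param, param_name, expected_dim):
        # A single value is broadcast across all spatial dims; otherwise the length must match.
        if len(list_param) == 1:
            return [list_param[0]] * expected_dim
        elif len(list_param) != expected_dim:
            raise ValueError(
                f"expected {param_name} to be a single integer value or a list of "
                f"{expected_dim} values to match the convolution dimensions, "
                f"but got {param_name}={list_param}")
        else:
            return list(list_param)

    assert expand_param_if_needed([2], "stride", 3) == [2, 2, 2]
    assert expand_param_if_needed([1, 2], "padding", 2) == [1, 2]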

torch_ipex/csrc/cpu/dbl/Common.h

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ dil::tensor try_gen_dil_tensor(const at::Tensor &input);
 at::Tensor gen_aten_tensor_by(dil::tensor&& tensor);
 at::Tensor empty_dil_tensor(at::IntArrayRef sizes, const at::TensorOptions& options);
 void sync_shape_from_dil_to_aten(const at::Tensor& ipex_tensor, const dil::tensor &dil_tensor);
+std::vector<int64_t> expand_param_if_needed(
+    at::IntArrayRef list_param, const char *param_name, int64_t expected_dim);
 
 } // namespace comm
 } // namespace dbl
