Skip to content

[ONNX] Add support for asymmetric padding for Onnx.AveragePool op #3923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value kernelSizeInt, Value strideInt,
bool ceilMode = false);

// Helper function to caculate the output tensor dims for pooling-like ops.
// Along each dim:
// dim_out =
// floor((dim_in + totalPadding - dilation * (kernelSize - 1) - 1) / stride) +
// 1
Value getOutputDimForPoolOps(OpBuilder &b, Location loc, Value in,
int64_t totalPadding, int64_t leftPadding,
Value dilationInt, Value kernelSizeInt,
Value strideInt, bool ceilMode);

// As above but for transposed convolution ops
// Along each dim:
// dim_out =
Expand Down
156 changes: 81 additions & 75 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,107 +456,113 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
patterns.onOp(
"AveragePool", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
SmallVector<int64_t> dilations;
if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
return failure();
if (autoPad != "NOTSET") {
// TODO: Add support for `auto_pad` != "NOTSET"
return rewriter.notifyMatchFailure(
binder.op, "unsupported conversion: auto_pad != NOTSET");
}

Torch::ValueTensorType resultType;
Value operand;
bool ceilMode, countIncludePad;
std::string autoPad;
if (binder.tensorOperand(operand) ||
binder.s64BoolAttr(ceilMode, "ceil_mode", false) ||
binder.s64BoolAttr(countIncludePad, "count_include_pad", false) ||
binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET") ||
binder.tensorResultType(resultType))
return failure();
return rewriter.notifyMatchFailure(
binder.op, "operand/ceil_mode/count_include_pad/auto_pad/"
"resultType bind failure");

// Determine the rank of input tensor.
std::optional<unsigned> maybeRank = Torch::getTensorRank(operand);
if (!maybeRank)
return rewriter.notifyMatchFailure(binder.op,
"Unimplemented: unranked tensor");
unsigned rank = *maybeRank;

SmallVector<int64_t> kernel, padding, strides;
if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) {
return failure();
}
if (kernel.size() != rank - 2) {
int64_t spatialRank = rank - 2;
SmallVector<int64_t> kernel, padding, strides, dilations;

if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}))
return rewriter.notifyMatchFailure(binder.op,
"kernel_shape bind failure");
if (kernel.size() != static_cast<size_t>(spatialRank))
return rewriter.notifyMatchFailure(
binder.op, "kernel list size does not match the number of axes");
}
SmallVector<int64_t> defaultPadding(2 * (rank - 2), 0);
if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
return failure();
}
if (padding.size() != 2 * (rank - 2)) {

if (binder.s64IntegerArrayAttr(padding, "pads", {}))
return rewriter.notifyMatchFailure(binder.op, "pads bind failure");
if (!padding.empty() &&
padding.size() != static_cast<size_t>(2 * spatialRank))
return rewriter.notifyMatchFailure(
binder.op,
"padding list size does not match twice the number of axes");
}
if (binder.s64IntegerArrayAttr(
strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
return failure();
}
if (strides.size() != 1 && strides.size() != rank - 2) {
binder.op, "padding list must contain (begin,end) pair for each "
"spatial axis");

if (binder.s64IntegerArrayAttr(strides, "strides", {}))
return rewriter.notifyMatchFailure(binder.op, "strides bind failure");
if (!strides.empty() &&
strides.size() != static_cast<size_t>(spatialRank))
return rewriter.notifyMatchFailure(
binder.op, "strides list size does not match the number of axes");
}

SmallVector<Value> cstKernel, cstPadding, cstStridesDilations;
for (int64_t i : kernel) {
cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));
}
// Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…]
// Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all
// axes x.
int64_t paddingSizeHalf = padding.size() / 2;
for (int64_t i = 0; i < paddingSizeHalf; ++i) {
// Check if onnx padding attribute is symmetric.
if (padding[i] != padding[i + paddingSizeHalf])
return rewriter.notifyMatchFailure(
binder.op, "onnx padding attribute is not symmetric");
cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
if (binder.s64IntegerArrayAttr(dilations, "dilations", {}))
return rewriter.notifyMatchFailure(binder.op,
"dilations bind failure");

// set default values for padding, strides, and dilations.
if (padding.empty())
padding.resize(spatialRank, 0);
if (strides.empty())
strides.resize(spatialRank, 1);
if (dilations.empty())
dilations.resize(spatialRank, 1);

// Padding for the beginning and ending along each spatial axis, it can
// take any value greater than or equal to 0. The value represent the
// number of pixels added to the beginning and end part of the
// corresponding axis. pads format should be as follow [x1_begin,
// x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added
// at the beginning of axis i and xi_end, the number of pixels added at
// the end of axis i.
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
if (autoPad != "NOTSET" && autoPad != "VALID") {
const bool isSameLower = autoPad == "SAME_LOWER";
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
padding.resize_for_overwrite(2 * spatialRank);
for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) {
const int64_t dilatedKernelSize =
dilations[dimIdx] * (kernel[dimIdx] - 1) + 1;
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) /
strides[dimIdx] -
1) *
strides[dimIdx] +
dilatedKernelSize - inputShape[dimIdx + 2];
totalPad = totalPad >= 0 ? totalPad : 0;
padding[dimIdx] =
isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2);
padding[spatialRank + dimIdx] = totalPad - padding[dimIdx];
}
}
for (int64_t i : strides) {
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(i)));

// If the padding is symmetric then we don't need seperate low/high
// padding values.
if (padding.size() == static_cast<size_t>(2 * spatialRank)) {
bool equal = true;
for (int i = 0; i < spatialRank; ++i) {
equal = equal && (padding[i] == padding[i + spatialRank]);
}
if (equal)
padding.resize(spatialRank);
}

// No dilations attribute in pytorch avgpool op, so use this trick to
// encode dilation into strides. Then in the following torchtolinalg
// lowering, decode strides into strides + dilation.
// Since the PyTorch AvgPool op does not contain the `dilation` arg,
// hence we use the trick of encoding dilation into strides. Then,
// during the torch->linalg lowering of the `AvgPool` op we decode the
// `strides` arg into strides values followed by dilation like:
// [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...]
if (binder.s64IntegerArrayAttr(
dilations, "dilations",
llvm::SmallVector<int64_t>(rank - 2, 1))) {
return failure();
}
for (auto dilation : dilations) {
cstStridesDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(dilation)));
}
SmallVector<int64_t> stridesDilations = strides;
stridesDilations.append(dilations);

Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstKernel);
Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstPadding);
Value kernelSizeList = createConstantIntList(binder, rewriter, kernel);
Value paddingList = createConstantIntList(binder, rewriter, padding);
Value stridesDilationsList =
rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(
Torch::IntType::get(binder.op->getContext())),
cstStridesDilations);
createConstantIntList(binder, rewriter, stridesDilations);
Value cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), ceilMode);
Value cstCountIncludePad = rewriter.create<Torch::ConstantBoolOp>(
Expand Down
14 changes: 10 additions & 4 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,24 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
Value N = getDimOp(rewriter, loc, self, 0);
Value C = getDimOp(rewriter, loc, self, 1);

SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

// Get dimension size for each dimension and calculate output size
for (int64_t i = dimensionality - 1; i > -1; --i) {
// In case of asymmetric padding the total padding value would be the sum of
// low and high padding. And, in case of symmetric padding it would just be
// the double of padding value for the corresponding dimension.
int64_t totalPadding = paddingInts[i] * 2;
if ((int64_t)paddingInts.size() == 2 * dimensionality)
totalPadding = paddingInts[i] + paddingInts[i + dimensionality];

Value dimSize = getDimOp(rewriter, loc, self, i + 2);
Value outDim = torch_to_linalg::getOutputDimForConvOps(
rewriter, loc, dimSize, paddingIntValues[i], dilationIntValues[i],
Value outDim = torch_to_linalg::getOutputDimForPoolOps(
rewriter, loc, dimSize, /*totalPadding=*/totalPadding,
/*leftPadding=*/paddingInts[i], dilationIntValues[i],
kernelSizeIntValues[i], strideIntValues[i], ceilMode);
outTensorShape.insert(outTensorShape.begin(), {outDim});
}
Expand Down
47 changes: 47 additions & 0 deletions lib/Conversion/TorchToLinalg/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,53 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
return castIntToIndex(b, loc, out);
}

Value torch_to_linalg::getOutputDimForPoolOps(OpBuilder &b, Location loc,
Value in, int64_t totalPadding,
int64_t leftPadding,
Value dilationInt,
Value kernelSizeInt,
Value strideInt, bool ceilMode) {
Value c1 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
Value totalPaddingIntCst =
b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(totalPadding));

// in + totalPadding
Value inAddTotalPadding = b.createOrFold<arith::AddIOp>(
loc, castIndexToInt64(b, loc, in), totalPaddingIntCst);

// dilation * (kernelSize - 1)
Value kernelSizeSub1 = b.createOrFold<arith::SubIOp>(loc, kernelSizeInt, c1);
Value dilationTimesKernelSize =
b.createOrFold<arith::MulIOp>(loc, dilationInt, kernelSizeSub1);

Value temp = b.createOrFold<arith::SubIOp>(loc, inAddTotalPadding,
dilationTimesKernelSize);
Value dividend = b.createOrFold<arith::SubIOp>(loc, temp, c1);
Value division;
if (ceilMode)
division = b.createOrFold<arith::CeilDivSIOp>(loc, dividend, strideInt);
else
division = b.createOrFold<arith::FloorDivSIOp>(loc, dividend, strideInt);
Value out = b.createOrFold<arith::AddIOp>(loc, division, c1);

if (!ceilMode)
return castIntToIndex(b, loc, out);

Value outMinusOneTimesStride =
b.createOrFold<arith::MulIOp>(loc, division, strideInt);
Value leftPaddingIntCst =
b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(leftPadding));
Value inAddLeftPadding = b.createOrFold<arith::AddIOp>(
loc, castIndexToInt64(b, loc, in), leftPaddingIntCst);

auto reduceOutputDimCond = b.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::uge, outMinusOneTimesStride, inAddLeftPadding);

auto reducedDim =
b.createOrFold<arith::SelectOp>(loc, reduceOutputDimCond, division, out);
return castIntToIndex(b, loc, reducedDim);
}

Value torch_to_linalg::getOutputDimForConvTransposeOps(
OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt,
Value kernelSizeInt, Value strideInt, Value outputPaddingInt) {
Expand Down
10 changes: 10 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,16 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>

// -----

// CHECK-LABEL: @test_averagepool_with_asymmetric_padding
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,1024,6,6],f32>
func.func @test_averagepool_with_asymmetric_padding(%arg1: !torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.opset_versions = {ai.onnx.contrib = 1000 : si64, ai.onnx.ml = 3 : si64, ai.onnx.preview.training = 1 : si64, ai.onnx.training = 1 : si64, com.microsoft = 1 : si64, com.microsoft.experimental = 1 : si64, com.microsoft.nchwc = 1 : si64, com.ms.internal.nhwc = 1 : si64, org.pytorch.aten = 1 : si64}, torch.onnx_meta.producer_name = "onnx.quantize", torch.onnx_meta.producer_version = "0.1.0"} {
%1 = torch.operator "onnx.AveragePool"(%arg1) {torch.onnx.auto_pad = "NOTSET", torch.onnx.ceil_mode = 0 : si64, torch.onnx.count_include_pad = 0 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1024,6,6],f32>) -> !torch.vtensor<[1,1024,1,1],f32>
// CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,1024,6,6],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1024,1,1],f32>
return %1 : !torch.vtensor<[1,1024,1,1],f32>
}

// -----

// CHECK-LABEL: @test_conv_with_strides_no_padding
func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0
Expand Down
Loading