Skip to content

Commit 42dff72

Browse files
Yu-ZhewenIanWood1
authored andcommitted
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
1 parent 288cd5e commit 42dff72

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -373,21 +373,21 @@ LogicalResult ClassTypeOp::verify() {
373373
// PrimLoopOp
374374
//===----------------------------------------------------------------------===//
375375

376-
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
377-
assert(point == getRegion());
376+
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionSuccessor successor) {
377+
assert(successor.getSuccessor() == &getRegion());
378378
return getIterArgsInit();
379379
}
380380

381381
void PrimLoopOp::getSuccessorRegions(
382382
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
383383
Region &region = getRegion();
384-
if (!point.getRegionOrNull()) {
384+
if (!point.getTerminatorPredecessorOrNull()) {
385385
regions.emplace_back(&region, region.getArguments().slice(1));
386386
return;
387387
}
388-
assert(point == region);
388+
assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &region);
389389
regions.emplace_back(&region, region.getArguments().slice(1));
390-
regions.emplace_back(getResults());
390+
regions.emplace_back(getOperation(), getResults());
391391
}
392392

393393
bool PrimLoopOp::isForLike() {
@@ -400,7 +400,7 @@ bool PrimLoopOp::isForLike() {
400400
//===----------------------------------------------------------------------===//
401401

402402
MutableOperandRange
403-
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
403+
PrimLoopConditionOp::getMutableSuccessorOperands(RegionSuccessor successor) {
404404
// Pass all operands except the condition to the successor which is the
405405
// parent loop op.
406406
return getIterArgsMutable();
@@ -452,8 +452,8 @@ void PrimIfOp::print(OpAsmPrinter &p) {
452452
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
453453
SmallVectorImpl<RegionSuccessor> &regions) {
454454
// The `then` and the `else` region branch back to the parent operation.
455-
if (point.getRegionOrNull()) {
456-
regions.push_back(RegionSuccessor(getResults()));
455+
if (point.getTerminatorPredecessorOrNull()) {
456+
regions.push_back(RegionSuccessor(getOperation(), getResults()));
457457
return;
458458
}
459459

@@ -5321,17 +5321,18 @@ template <typename CalculateOp>
53215321
static void
53225322
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
53235323
SmallVectorImpl<RegionSuccessor> &regions) {
5324-
if (!point.getRegionOrNull()) {
5324+
if (!point.getTerminatorPredecessorOrNull()) {
53255325
// First thing the op does is branch into the calculation.
53265326
regions.emplace_back(&op.getCalculation());
53275327
return;
53285328
}
5329-
if (point == op.getBody()) {
5329+
Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion();
5330+
if (region == &op.getBody()) {
53305331
// Body returns control to the outer op, passing through results.
5331-
regions.emplace_back(op.getResults());
5332+
regions.emplace_back(op.getOperation(), op.getResults());
53325333
return;
53335334
}
5334-
assert(point == op.getCalculation());
5335+
assert(region == &op.getCalculation());
53355336
// Calculation branches to the body.
53365337
regions.emplace_back(&op.getBody());
53375338
}
@@ -5355,7 +5356,7 @@ void DtypeCalculateOp::getSuccessorRegions(
53555356
//===----------------------------------------------------------------------===//
53565357

53575358
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
5358-
RegionBranchPoint point) {
5359+
RegionSuccessor successor) {
53595360
// The shape operands don't get forwarded to the body.
53605361
// MutableOperandRange always has an owning operation, even if empty, so
53615362
// create a 0-length range.
@@ -5846,7 +5847,7 @@ LogicalResult AtenKthvalueOp::verify() {
58465847
//===----------------------------------------------------------------------===//
58475848

58485849
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
5849-
RegionBranchPoint point) {
5850+
RegionSuccessor successor) {
58505851
// The dtype operands don't get forwarded to the body.
58515852
// MutableOperandRange always has an owning operation, even if empty, so
58525853
// create a 0-length range.

0 commit comments

Comments
 (0)