Skip to content

Commit c2fcaf8

Browse files
committed
[MLIR][Presburger] Simplex: refactor (symbolic)lex to support specifying multiple varKinds as symbols
This is also required to support lexmin for relations. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D128931
1 parent fdf1fda commit c2fcaf8

File tree

4 files changed

+80
-48
lines changed

4 files changed

+80
-48
lines changed

mlir/include/mlir/Analysis/Presburger/Simplex.h

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Support/LogicalResult.h"
2424
#include "llvm/ADT/ArrayRef.h"
2525
#include "llvm/ADT/Optional.h"
26+
#include "llvm/ADT/SmallBitVector.h"
2627
#include "llvm/ADT/SmallVector.h"
2728
#include "llvm/Support/StringSaver.h"
2829
#include "llvm/Support/raw_ostream.h"
@@ -210,14 +211,18 @@ class SimplexBase {
210211

211212
protected:
212213
/// Construct a SimplexBase with the specified number of variables and fixed
213-
/// columns.
214+
/// columns. The first overload should be used when there are nosymbols.
215+
/// With the second overload, the specified range of vars will be marked
216+
/// as symbols. With the third overload, `isSymbol` is a bitmask denoting
217+
/// which vars are symbols. The size of `isSymbol` must be `nVar`.
214218
///
215219
/// For example, Simplex uses two fixed columns: the denominator and the
216220
/// constant term, whereas LexSimplex has an extra fixed column for the
217221
/// so-called big M parameter. For more information see the documentation for
218222
/// LexSimplex.
219-
SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
220-
unsigned nSymbol);
223+
SimplexBase(unsigned nVar, bool mustUseBigM);
224+
SimplexBase(unsigned nVar, bool mustUseBigM,
225+
const llvm::SmallBitVector &isSymbol);
221226

222227
enum class Orientation { Row, Column };
223228

@@ -422,12 +427,16 @@ class LexSimplexBase : public SimplexBase {
422427
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
423428

424429
protected:
425-
LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol)
426-
: SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {}
430+
LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {}
431+
LexSimplexBase(unsigned nVar, const llvm::SmallBitVector &isSymbol)
432+
: SimplexBase(nVar, /*mustUseBigM=*/true, isSymbol) {}
427433
explicit LexSimplexBase(const IntegerRelation &constraints)
428-
: LexSimplexBase(constraints.getNumVars(),
429-
constraints.getVarKindOffset(VarKind::Symbol),
430-
constraints.getNumSymbolVars()) {
434+
: LexSimplexBase(constraints.getNumVars()) {
435+
intersectIntegerRelation(constraints);
436+
}
437+
explicit LexSimplexBase(const IntegerRelation &constraints,
438+
const llvm::SmallBitVector &isSymbol)
439+
: LexSimplexBase(constraints.getNumVars(), isSymbol) {
431440
intersectIntegerRelation(constraints);
432441
}
433442

@@ -470,13 +479,12 @@ class LexSimplexBase : public SimplexBase {
470479
/// provides support for integer-exact redundancy and separateness checks.
471480
class LexSimplex : public LexSimplexBase {
472481
public:
473-
explicit LexSimplex(unsigned nVar)
474-
: LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {}
482+
explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {}
483+
// Note that LexSimplex does NOT support symbolic lexmin;
484+
// use SymbolicLexSimplex if that is required. LexSimplex ignores the VarKinds
485+
// of the passed IntegerRelation. Symbols will be treated as ordinary vars.
475486
explicit LexSimplex(const IntegerRelation &constraints)
476-
: LexSimplexBase(constraints) {
477-
assert(constraints.getNumSymbolVars() == 0 &&
478-
"LexSimplex does not support symbols!");
479-
}
487+
: LexSimplexBase(constraints) {}
480488

481489
/// Return the lexicographically minimum rational solution to the constraints.
482490
MaybeOptimum<SmallVector<Fraction, 8>> findRationalLexMin();
@@ -521,10 +529,9 @@ class LexSimplex : public LexSimplexBase {
521529

522530
/// Represents the result of a symbolic lexicographic minimization computation.
523531
struct SymbolicLexMin {
524-
SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols)
525-
: lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols),
526-
unboundedDomain(
527-
PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {}
532+
SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs)
533+
: lexmin(domainSpace, numOutputs),
534+
unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {}
528535

529536
/// This maps assignments of symbols to the corresponding lexmin.
530537
/// Takes no value when no integer sample exists for the assignment or if the
@@ -569,30 +576,40 @@ class SymbolicLexSimplex : public LexSimplexBase {
569576
/// `constraints` is the set for which the symbolic lexmin will be computed.
570577
/// `symbolDomain` is the set of values of the symbols for which the lexmin
571578
/// will be computed. `symbolDomain` should have a dim var for every symbol in
572-
/// `constraints`, and no other vars.
579+
/// `constraints`, and no other vars. `isSymbol` specifies which vars of
580+
/// `constraints` should be considered as symbols.
581+
///
582+
/// The resulting SymbolicLexMin's space will be compatible with that of
583+
/// symbolDomain.
573584
SymbolicLexSimplex(const IntegerRelation &constraints,
574-
const IntegerPolyhedron &symbolDomain)
575-
: SymbolicLexSimplex(constraints,
576-
constraints.getVarKindOffset(VarKind::Symbol),
577-
symbolDomain) {
578-
assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars());
585+
const IntegerPolyhedron &symbolDomain,
586+
const llvm::SmallBitVector &isSymbol)
587+
: LexSimplexBase(constraints, isSymbol), domainPoly(symbolDomain),
588+
domainSimplex(symbolDomain) {
589+
// TODO consider supporting this case. It amounts
590+
// to just returning the input constraints.
591+
assert(domainPoly.getNumVars() > 0 &&
592+
"there must be some non-symbols to optimize!");
579593
}
580594

581-
/// An overload to select some other subrange of ids as symbols for lexmin.
595+
/// An overload to select some subrange of ids as symbols for lexmin.
582596
/// The symbol ids are the range of ids with absolute index
583597
/// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
584-
/// symbolDomain should only have dim ids.
585598
SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
586599
const IntegerPolyhedron &symbolDomain)
587-
: LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset,
588-
symbolDomain.getNumVars()),
589-
domainPoly(symbolDomain), domainSimplex(symbolDomain) {
590-
// TODO consider supporting this case. It amounts
591-
// to just returning the input constraints.
592-
assert(domainPoly.getNumVars() > 0 &&
593-
"there must be some non-symbols to optimize!");
594-
assert(domainPoly.getNumVars() == domainPoly.getNumDimVars());
595-
intersectIntegerRelation(constraints);
600+
: SymbolicLexSimplex(constraints, symbolDomain,
601+
getSubrangeBitVector(constraints.getNumVars(),
602+
symbolOffset,
603+
symbolDomain.getNumVars())) {}
604+
605+
/// An overload to select the symbols of `constraints` as symbols for lexmin.
606+
SymbolicLexSimplex(const IntegerRelation &constraints,
607+
const IntegerPolyhedron &symbolDomain)
608+
: SymbolicLexSimplex(constraints,
609+
constraints.getVarKindOffset(VarKind::Symbol),
610+
symbolDomain) {
611+
assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars() &&
612+
"symbolDomain must have as many vars as constraints has symbols!");
596613
}
597614

598615
/// The lexmin will be stored as a function `lexmin` from symbols to
@@ -678,9 +695,7 @@ class Simplex : public SimplexBase {
678695
enum class Direction { Up, Down };
679696

680697
Simplex() = delete;
681-
explicit Simplex(unsigned nVar)
682-
: SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0,
683-
/*nSymbol=*/0) {}
698+
explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {}
684699
explicit Simplex(const IntegerRelation &constraints)
685700
: Simplex(constraints.getNumVars()) {
686701
intersectIntegerRelation(constraints);

mlir/include/mlir/Analysis/Presburger/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Support/LLVM.h"
1717
#include "llvm/ADT/STLExtras.h"
18+
#include "llvm/ADT/SmallBitVector.h"
1819

1920
namespace mlir {
2021
namespace presburger {
@@ -120,6 +121,9 @@ SmallVector<int64_t, 8> getDivUpperBound(ArrayRef<int64_t> dividend,
120121
SmallVector<int64_t, 8> getDivLowerBound(ArrayRef<int64_t> dividend,
121122
int64_t divisor, unsigned localVarIdx);
122123

124+
llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
125+
unsigned numSet);
126+
123127
/// Check if the pos^th variable can be expressed as a floordiv of an affine
124128
/// function of other variables (where the divisor is a positive constant).
125129
/// `foundRepr` contains a boolean for each variable indicating if the

mlir/lib/Analysis/Presburger/Simplex.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,28 @@ scaleAndAddForAssert(ArrayRef<int64_t> a, int64_t scale, ArrayRef<int64_t> b) {
3131
return res;
3232
}
3333

34-
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
35-
unsigned nSymbol)
36-
: usingBigM(mustUseBigM), nRedundant(0), nSymbol(nSymbol),
34+
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
35+
: usingBigM(mustUseBigM), nRedundant(0), nSymbol(0),
3736
tableau(0, getNumFixedCols() + nVar), empty(false) {
38-
assert(symbolOffset + nSymbol <= nVar);
39-
4037
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
4138
for (unsigned i = 0; i < nVar; ++i) {
4239
var.emplace_back(Orientation::Column, /*restricted=*/false,
4340
/*pos=*/getNumFixedCols() + i);
4441
colUnknown.push_back(i);
4542
}
43+
}
4644

47-
// Move the symbols to be in columns [3, 3 + nSymbol).
48-
for (unsigned i = 0; i < nSymbol; ++i) {
49-
var[symbolOffset + i].isSymbol = true;
50-
swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i);
45+
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM,
46+
const llvm::SmallBitVector &isSymbol)
47+
: SimplexBase(nVar, mustUseBigM) {
48+
assert(isSymbol.size() == nVar && "invalid bitmask!");
49+
// Invariant: nSymbol is the number of symbols that have been marked
50+
// already and these occupy the columns
51+
// [getNumFixedCols(), getNumFixedCols() + nSymbol).
52+
for (unsigned symbolIdx : isSymbol.set_bits()) {
53+
var[symbolIdx].isSymbol = true;
54+
swapColumns(var[symbolIdx].pos, getNumFixedCols() + nSymbol);
55+
++nSymbol;
5156
}
5257
}
5358

@@ -502,7 +507,7 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
502507
}
503508

504509
SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
505-
SymbolicLexMin result(nSymbol, var.size() - nSymbol);
510+
SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol);
506511

507512
/// The algorithm is more naturally expressed recursively, but we implement
508513
/// it iteratively here to avoid potential issues with stack overflows in the

mlir/lib/Analysis/Presburger/Utils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
253253
return repr;
254254
}
255255

256+
llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
257+
unsigned setOffset,
258+
unsigned numSet) {
259+
llvm::SmallBitVector vec(len, false);
260+
vec.set(setOffset, setOffset + numSet);
261+
return vec;
262+
}
263+
256264
void presburger::removeDuplicateDivs(
257265
std::vector<SmallVector<int64_t, 8>> &divs,
258266
SmallVectorImpl<unsigned> &denoms, unsigned localOffset,

0 commit comments

Comments
 (0)