Skip to content
50 changes: 45 additions & 5 deletions mlir/include/mlir/Analysis/Presburger/Fraction.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
// This is a simple class to represent fractions. It supports multiplication,
// This is a simple class to represent fractions. It supports arithmetic,
// comparison, floor, and ceiling operations.
//
//===----------------------------------------------------------------------===//
Expand All @@ -30,15 +30,15 @@ struct Fraction {
Fraction() = default;

/// Construct a Fraction from a numerator and denominator.
Fraction(const MPInt &oNum, const MPInt &oDen) : num(oNum), den(oDen) {
Fraction(const MPInt &oNum, const MPInt &oDen = MPInt(1)) : num(oNum), den(oDen) {
if (den < 0) {
num = -num;
den = -den;
}
}
/// Overloads for passing literals.
Fraction(const MPInt &num, int64_t den) : Fraction(num, MPInt(den)) {}
Fraction(int64_t num, const MPInt &den) : Fraction(MPInt(num), den) {}
Fraction(const MPInt &num, int64_t den = 1) : Fraction(num, MPInt(den)) {}
Fraction(int64_t num, const MPInt &den = MPInt(1)) : Fraction(MPInt(num), den) {}
Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {}

// Return the value of the fraction as an integer. This should only be called
Expand All @@ -48,6 +48,10 @@ struct Fraction {
return num / den;
}

llvm::raw_ostream &print(llvm::raw_ostream &os) const {
return os << "(" << num << "/" << den << ")";
}

/// The numerator and denominator, respectively. The denominator is always
/// positive.
MPInt num{0}, den{1};
Expand Down Expand Up @@ -95,8 +99,44 @@ inline bool operator>=(const Fraction &x, const Fraction &y) {
return compare(x, y) >= 0;
}

inline Fraction reduce(const Fraction &f) {
if (f == Fraction(0))
return f;
MPInt g = gcd(f.num, f.den);
return Fraction(f.num / g, f.den / g);
}

inline Fraction operator*(const Fraction &x, const Fraction &y) {
return Fraction(x.num * y.num, x.den * y.den);
return reduce(Fraction(x.num * y.num, x.den * y.den));
}

inline Fraction operator/(const Fraction &x, const Fraction &y) {
return reduce(Fraction(x.num * y.den, x.den * y.num));
}

inline Fraction operator+(const Fraction &x, const Fraction &y) {
return reduce(Fraction(x.num * y.den + x.den * y.num, x.den * y.den));
}

inline Fraction operator-(const Fraction &x, const Fraction &y) {
return reduce(Fraction(x.num * y.den - x.den * y.num, x.den * y.den));
}

inline Fraction& operator+=(const Fraction &g, const Fraction &f) {
Fraction *r = NULL;
*r = g+f;
return *r;
}

inline Fraction& operator-=(const Fraction &g, const Fraction &f) {
Fraction *r = NULL;
*r = g-f;
return *r;
}

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Fraction &x) {
x.print(os);
return os;
}

} // namespace presburger
Expand Down