diff --git a/src/constraint_solver/constraint_solver.h b/src/constraint_solver/constraint_solver.h index aa761a6f1e..e334c400b6 100644 --- a/src/constraint_solver/constraint_solver.h +++ b/src/constraint_solver/constraint_solver.h @@ -2867,6 +2867,9 @@ class Solver { // Internal. IntExpr* CastExpression(IntVar* const var) const; + bool IsADifference(IntExpr* expr, + IntExpr** const left, + IntExpr** const right); const string name_; const SolverParameters parameters_; diff --git a/src/constraint_solver/expr_cst.cc b/src/constraint_solver/expr_cst.cc index 7b81e96951..87d7066b7d 100644 --- a/src/constraint_solver/expr_cst.cc +++ b/src/constraint_solver/expr_cst.cc @@ -81,12 +81,24 @@ string EqualityExprCst::DebugString() const { Constraint* Solver::MakeEquality(IntExpr* const e, int64 v) { CHECK_EQ(this, e->solver()); - return RevAlloc(new EqualityExprCst(this, e, v)); + IntExpr* left = NULL; + IntExpr* right = NULL; + if (v == 0 && IsADifference(e, &left, &right)) { + return MakeEquality(left->Var(), right->Var()); + } else { + return RevAlloc(new EqualityExprCst(this, e, v)); + } } Constraint* Solver::MakeEquality(IntExpr* const e, int v) { CHECK_EQ(this, e->solver()); - return RevAlloc(new EqualityExprCst(this, e, v)); + IntExpr* left = NULL; + IntExpr* right = NULL; + if (v == 0 && IsADifference(e, &left, &right)) { + return MakeEquality(left->Var(), right->Var()); + } else { + return RevAlloc(new EqualityExprCst(this, e, v)); + } } //----------------------------------------------------------------------------- @@ -289,12 +301,24 @@ string DiffCst::DebugString() const { Constraint* Solver::MakeNonEquality(IntVar* const e, int64 v) { CHECK_EQ(this, e->solver()); - return RevAlloc(new DiffCst(this, e, v)); + IntExpr* left = NULL; + IntExpr* right = NULL; + if (v == 0 && IsADifference(e, &left, &right)) { + return MakeNonEquality(left->Var(), right->Var()); + } else { + return RevAlloc(new DiffCst(this, e, v)); + } } Constraint* Solver::MakeNonEquality(IntVar* const e, int v) { CHECK_EQ(this, e->solver()); - return RevAlloc(new DiffCst(this, e, v)); + IntExpr* left = NULL; + IntExpr* right = NULL; + if (v == 0 && IsADifference(e, &left, &right)) { + return MakeNonEquality(left->Var(), right->Var()); + } else { + return RevAlloc(new DiffCst(this, e, v)); + } } // ----- is_equal_cst Constraint ----- @@ -350,6 +374,11 @@ class IsEqualCstCt : public CastConstraint { } // namespace IntVar* Solver::MakeIsEqualCstVar(IntVar* const var, int64 value) { + IntExpr* left = NULL; + IntExpr* right = NULL; + if (value == 0 && IsADifference(var, &left, &right)) { + return MakeIsEqualVar(left, right); + } return var->IsEqual(value); } @@ -378,7 +407,13 @@ Constraint* Solver::MakeIsEqualCstCt(IntVar* const var, var, value, ModelCache::VAR_CONSTANT_IS_EQUAL); - return RevAlloc(new IsEqualCstCt(this, var, value, boolvar)); + IntExpr* left = NULL; + IntExpr* right = NULL; + if (value == 0 && IsADifference(var, &left, &right)) { + return MakeIsEqualCt(left, right, boolvar); + } else { + return RevAlloc(new IsEqualCstCt(this, var, value, boolvar)); + } } // ----- is_diff_cst Constraint ----- @@ -438,6 +473,11 @@ class IsDiffCstCt : public CastConstraint { } // namespace IntVar* Solver::MakeIsDifferentCstVar(IntVar* const var, int64 value) { + IntExpr* left = NULL; + IntExpr* right = NULL; + if (value == 0 && IsADifference(var, &left, &right)) { + return MakeIsDifferentVar(left, right); + } return var->IsDifferent(value); } @@ -470,7 +510,13 @@ Constraint* Solver::MakeIsDifferentCstCt(IntVar* const var, var, value, ModelCache::VAR_CONSTANT_IS_NOT_EQUAL); - return RevAlloc(new IsDiffCstCt(this, var, value, boolvar)); + IntExpr* left = NULL; + IntExpr* right = NULL; + if (value == 0 && IsADifference(var, &left, &right)) { + return MakeIsDifferentCt(left, right, boolvar); + } else { + return RevAlloc(new IsDiffCstCt(this, var, value, boolvar)); + } } // ----- is_greater_equal_cst Constraint ----- diff --git a/src/constraint_solver/expressions.cc b/src/constraint_solver/expressions.cc index 46118a544d..7e73d49389 100644 --- a/src/constraint_solver/expressions.cc +++ b/src/constraint_solver/expressions.cc @@ -3084,6 +3084,8 @@ class SubIntExpr : public BaseIntExpr { visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this); } + IntExpr* left() const { return left_; } + IntExpr* right() const { return right_; } private: IntExpr* const left_; IntExpr* const right_; @@ -6786,4 +6788,22 @@ IntVar* BaseIntExpr::CastToVar() { LinkVarExpr(solver(), this, var); return var; } + +// Discovery methods +bool Solver::IsADifference(IntExpr* expr, + IntExpr** const left, + IntExpr** const right) { + + if (expr->IsVar()) { + IntVar* const expr_var = expr->Var(); + expr = CastExpression(expr_var); + } + SubIntExpr* const sub_expr = dynamic_cast(expr); + if (sub_expr != NULL) { + *left = sub_expr->left(); + *right = sub_expr->right(); + return true; + } + return false; +} } // namespace operations_research diff --git a/src/constraint_solver/range_cst.cc b/src/constraint_solver/range_cst.cc index 28bde3f82b..1a701f732c 100644 --- a/src/constraint_solver/range_cst.cc +++ b/src/constraint_solver/range_cst.cc @@ -395,6 +395,11 @@ Constraint* Solver::MakeNonEquality(IntVar* const l, IntVar* const r) { CHECK(r != NULL) << "left expression NULL, maybe a bad cast"; CHECK_EQ(this, l->solver()); CHECK_EQ(this, r->solver()); + if (l->Bound()) { + return MakeNonEquality(r, l->Min()); + } else if (r->Bound()) { + return MakeNonEquality(l, r->Min()); + } return RevAlloc(new DiffVar(this, l, r)); } diff --git a/src/flatzinc/parser.cc b/src/flatzinc/parser.cc index 79529c9edb..22a5a48720 100644 --- a/src/flatzinc/parser.cc +++ b/src/flatzinc/parser.cc @@ -809,6 +809,44 @@ bool ParserState::Presolve(CtSpec* const spec) { spec->ReplaceArg(2, new AST::IntLit(bound)); return true; } + if (id == "int_abs" && + !ContainsKey(stored_constraints_, index) && + spec->Arg(0)->isIntVar() && + spec->Arg(1)->isIntVar()) { + abs_map_[spec->Arg(1)->getIntVar()] = spec->Arg(0)->getIntVar(); + stored_constraints_.insert(index); + return true; + } + if (id == "int_eq_reif") { + if (spec->Arg(0)->isIntVar() && + ContainsKey(abs_map_, spec->Arg(0)->getIntVar()) && + spec->Arg(1)->isInt() && + spec->Arg(1)->getInt() == 0) { + VLOG(1) << " - presolve: remove abs() in " << spec->DebugString(); + dynamic_cast(spec->Arg(0))->i = + abs_map_[spec->Arg(0)->getIntVar()]; + } + } + if (id == "int_ne_reif") { + if (spec->Arg(0)->isIntVar() && + ContainsKey(abs_map_, spec->Arg(0)->getIntVar()) && + spec->Arg(1)->isInt() && + spec->Arg(1)->getInt() == 0) { + VLOG(1) << " - presolve: remove abs() in " << spec->DebugString(); + dynamic_cast(spec->Arg(0))->i = + abs_map_[spec->Arg(0)->getIntVar()]; + } + } + if (id == "int_ne") { + if (spec->Arg(0)->isIntVar() && + ContainsKey(abs_map_, spec->Arg(0)->getIntVar()) && + spec->Arg(1)->isInt() && + spec->Arg(1)->getInt() == 0) { + VLOG(1) << " - presolve: remove abs() in " << spec->DebugString(); + dynamic_cast(spec->Arg(0))->i = + abs_map_[spec->Arg(0)->getIntVar()]; + } + } return false; } diff --git a/src/flatzinc/parser.h b/src/flatzinc/parser.h index 229084d383..82fdb3283d 100644 --- a/src/flatzinc/parser.h +++ b/src/flatzinc/parser.h @@ -205,6 +205,8 @@ class ParserState { hash_map > constraints_per_id_; std::vector > constraints_per_int_variables_; std::vector > constraints_per_bool_variables_; + hash_map abs_map_; + hash_map > differences_; }; AST::Node* ArrayOutput(AST::Call* ann);