1776 lines
61 KiB
C++
1776 lines
61 KiB
C++
// Copyright 2010-2024 Google LLC
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include <algorithm>
|
|
#include <cstdint>
|
|
#include <functional>
|
|
#include <limits>
|
|
#include <memory>
|
|
#include <numeric>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/str_format.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "ortools/base/logging.h"
|
|
#include "ortools/base/types.h"
|
|
#include "ortools/constraint_solver/constraint_solver.h"
|
|
#include "ortools/constraint_solver/constraint_solveri.h"
|
|
#include "ortools/util/range_minimum_query.h"
|
|
#include "ortools/util/string_array.h"
|
|
|
|
ABSL_FLAG(bool, cp_disable_element_cache, true,
|
|
"If true, caching for IntElement is disabled.");
|
|
|
|
namespace operations_research {
|
|
|
|
// ----- IntExprElement -----
|
|
void LinkVarExpr(Solver* s, IntExpr* expr, IntVar* var);
|
|
|
|
namespace {
|
|
|
|
template <class T>
|
|
class VectorLess {
|
|
public:
|
|
explicit VectorLess(const std::vector<T>* values) : values_(values) {}
|
|
bool operator()(const T& x, const T& y) const {
|
|
return (*values_)[x] < (*values_)[y];
|
|
}
|
|
|
|
private:
|
|
const std::vector<T>* values_;
|
|
};
|
|
|
|
template <class T>
|
|
class VectorGreater {
|
|
public:
|
|
explicit VectorGreater(const std::vector<T>* values) : values_(values) {}
|
|
bool operator()(const T& x, const T& y) const {
|
|
return (*values_)[x] > (*values_)[y];
|
|
}
|
|
|
|
private:
|
|
const std::vector<T>* values_;
|
|
};
|
|
|
|
// ----- BaseIntExprElement -----
|
|
|
|
class BaseIntExprElement : public BaseIntExpr {
|
|
public:
|
|
BaseIntExprElement(Solver* s, IntVar* e);
|
|
~BaseIntExprElement() override {}
|
|
int64_t Min() const override;
|
|
int64_t Max() const override;
|
|
void Range(int64_t* mi, int64_t* ma) override;
|
|
void SetMin(int64_t m) override;
|
|
void SetMax(int64_t m) override;
|
|
void SetRange(int64_t mi, int64_t ma) override;
|
|
bool Bound() const override { return (expr_->Bound()); }
|
|
// TODO(user) : improve me, the previous test is not always true
|
|
void WhenRange(Demon* d) override { expr_->WhenRange(d); }
|
|
|
|
protected:
|
|
virtual int64_t ElementValue(int index) const = 0;
|
|
virtual int64_t ExprMin() const = 0;
|
|
virtual int64_t ExprMax() const = 0;
|
|
|
|
IntVar* const expr_;
|
|
|
|
private:
|
|
void UpdateSupports() const;
|
|
template <typename T>
|
|
void UpdateElementIndexBounds(T check_value) {
|
|
const int64_t emin = ExprMin();
|
|
const int64_t emax = ExprMax();
|
|
int64_t nmin = emin;
|
|
int64_t value = ElementValue(nmin);
|
|
while (nmin < emax && check_value(value)) {
|
|
nmin++;
|
|
value = ElementValue(nmin);
|
|
}
|
|
if (nmin == emax && check_value(value)) {
|
|
solver()->Fail();
|
|
}
|
|
int64_t nmax = emax;
|
|
value = ElementValue(nmax);
|
|
while (nmax >= nmin && check_value(value)) {
|
|
nmax--;
|
|
value = ElementValue(nmax);
|
|
}
|
|
expr_->SetRange(nmin, nmax);
|
|
}
|
|
|
|
mutable int64_t min_;
|
|
mutable int min_support_;
|
|
mutable int64_t max_;
|
|
mutable int max_support_;
|
|
mutable bool initial_update_;
|
|
IntVarIterator* const expr_iterator_;
|
|
};
|
|
|
|
BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
|
|
: BaseIntExpr(s),
|
|
expr_(e),
|
|
min_(0),
|
|
min_support_(-1),
|
|
max_(0),
|
|
max_support_(-1),
|
|
initial_update_(true),
|
|
expr_iterator_(expr_->MakeDomainIterator(true)) {
|
|
CHECK(s != nullptr);
|
|
CHECK(e != nullptr);
|
|
}
|
|
|
|
int64_t BaseIntExprElement::Min() const {
|
|
UpdateSupports();
|
|
return min_;
|
|
}
|
|
|
|
int64_t BaseIntExprElement::Max() const {
|
|
UpdateSupports();
|
|
return max_;
|
|
}
|
|
|
|
void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {
|
|
UpdateSupports();
|
|
*mi = min_;
|
|
*ma = max_;
|
|
}
|
|
|
|
void BaseIntExprElement::SetMin(int64_t m) {
|
|
UpdateElementIndexBounds([m](int64_t value) { return value < m; });
|
|
}
|
|
|
|
void BaseIntExprElement::SetMax(int64_t m) {
|
|
UpdateElementIndexBounds([m](int64_t value) { return value > m; });
|
|
}
|
|
|
|
void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {
|
|
if (mi > ma) {
|
|
solver()->Fail();
|
|
}
|
|
UpdateElementIndexBounds(
|
|
[mi, ma](int64_t value) { return value < mi || value > ma; });
|
|
}
|
|
|
|
void BaseIntExprElement::UpdateSupports() const {
|
|
if (initial_update_ || !expr_->Contains(min_support_) ||
|
|
!expr_->Contains(max_support_)) {
|
|
const int64_t emin = ExprMin();
|
|
const int64_t emax = ExprMax();
|
|
int64_t min_value = ElementValue(emax);
|
|
int64_t max_value = min_value;
|
|
int min_support = emax;
|
|
int max_support = emax;
|
|
const uint64_t expr_size = expr_->Size();
|
|
if (expr_size > 1) {
|
|
if (expr_size == emax - emin + 1) {
|
|
// Value(emax) already stored in min_value, max_value.
|
|
for (int64_t index = emin; index < emax; ++index) {
|
|
const int64_t value = ElementValue(index);
|
|
if (value > max_value) {
|
|
max_value = value;
|
|
max_support = index;
|
|
} else if (value < min_value) {
|
|
min_value = value;
|
|
min_support = index;
|
|
}
|
|
}
|
|
} else {
|
|
for (const int64_t index : InitAndGetValues(expr_iterator_)) {
|
|
if (index >= emin && index <= emax) {
|
|
const int64_t value = ElementValue(index);
|
|
if (value > max_value) {
|
|
max_value = value;
|
|
max_support = index;
|
|
} else if (value < min_value) {
|
|
min_value = value;
|
|
min_support = index;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Solver* s = solver();
|
|
s->SaveAndSetValue(&min_, min_value);
|
|
s->SaveAndSetValue(&min_support_, min_support);
|
|
s->SaveAndSetValue(&max_, max_value);
|
|
s->SaveAndSetValue(&max_support_, max_support);
|
|
s->SaveAndSetValue(&initial_update_, false);
|
|
}
|
|
}
|
|
|
|
// ----- IntElementConstraint -----
|
|
|
|
// This constraint implements 'elem' == 'values'['index'].
|
|
// It scans the bounds of 'elem' to propagate on the domain of 'index'.
|
|
// It scans the domain of 'index' to compute the new bounds of 'elem'.
|
|
class IntElementConstraint : public CastConstraint {
|
|
public:
|
|
IntElementConstraint(Solver* const s, const std::vector<int64_t>& values,
|
|
IntVar* const index, IntVar* const elem)
|
|
: CastConstraint(s, elem),
|
|
values_(values),
|
|
index_(index),
|
|
index_iterator_(index_->MakeDomainIterator(true)) {
|
|
CHECK(index != nullptr);
|
|
}
|
|
|
|
void Post() override {
|
|
Demon* const d =
|
|
solver()->MakeDelayedConstraintInitialPropagateCallback(this);
|
|
index_->WhenDomain(d);
|
|
target_var_->WhenRange(d);
|
|
}
|
|
|
|
void InitialPropagate() override {
|
|
index_->SetRange(0, values_.size() - 1);
|
|
const int64_t target_var_min = target_var_->Min();
|
|
const int64_t target_var_max = target_var_->Max();
|
|
int64_t new_min = target_var_max;
|
|
int64_t new_max = target_var_min;
|
|
to_remove_.clear();
|
|
for (const int64_t index : InitAndGetValues(index_iterator_)) {
|
|
const int64_t value = values_[index];
|
|
if (value < target_var_min || value > target_var_max) {
|
|
to_remove_.push_back(index);
|
|
} else {
|
|
if (value < new_min) {
|
|
new_min = value;
|
|
}
|
|
if (value > new_max) {
|
|
new_max = value;
|
|
}
|
|
}
|
|
}
|
|
target_var_->SetRange(new_min, new_max);
|
|
if (!to_remove_.empty()) {
|
|
index_->RemoveValues(to_remove_);
|
|
}
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IntElementConstraint(%s, %s, %s)",
|
|
absl::StrJoin(values_, ", "), index_->DebugString(),
|
|
target_var_->DebugString());
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
index_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
|
|
target_var_);
|
|
visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
}
|
|
|
|
private:
|
|
const std::vector<int64_t> values_;
|
|
IntVar* const index_;
|
|
IntVarIterator* const index_iterator_;
|
|
std::vector<int64_t> to_remove_;
|
|
};
|
|
|
|
// ----- IntExprElement
|
|
|
|
IntVar* BuildDomainIntVar(Solver* solver, std::vector<int64_t>* values);
|
|
|
|
class IntExprElement : public BaseIntExprElement {
|
|
public:
|
|
IntExprElement(Solver* const s, const std::vector<int64_t>& vals,
|
|
IntVar* const expr)
|
|
: BaseIntExprElement(s, expr), values_(vals) {}
|
|
|
|
~IntExprElement() override {}
|
|
|
|
std::string name() const override {
|
|
const int size = values_.size();
|
|
if (size > 10) {
|
|
return absl::StrFormat("IntElement(array of size %d, %s)", size,
|
|
expr_->name());
|
|
} else {
|
|
return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
|
|
expr_->name());
|
|
}
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
const int size = values_.size();
|
|
if (size > 10) {
|
|
return absl::StrFormat("IntElement(array of size %d, %s)", size,
|
|
expr_->DebugString());
|
|
} else {
|
|
return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
|
|
expr_->DebugString());
|
|
}
|
|
}
|
|
|
|
IntVar* CastToVar() override {
|
|
Solver* const s = solver();
|
|
IntVar* const var = s->MakeIntVar(values_);
|
|
s->AddCastConstraint(
|
|
s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,
|
|
this);
|
|
return var;
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
expr_);
|
|
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
}
|
|
|
|
protected:
|
|
int64_t ElementValue(int index) const override {
|
|
DCHECK_LT(index, values_.size());
|
|
return values_[index];
|
|
}
|
|
int64_t ExprMin() const override {
|
|
return std::max<int64_t>(0, expr_->Min());
|
|
}
|
|
int64_t ExprMax() const override {
|
|
return values_.empty()
|
|
? 0
|
|
: std::min<int64_t>(values_.size() - 1, expr_->Max());
|
|
}
|
|
|
|
private:
|
|
const std::vector<int64_t> values_;
|
|
};
|
|
|
|
// ----- Range Minimum Query-based Element -----
|
|
|
|
class RangeMinimumQueryExprElement : public BaseIntExpr {
|
|
public:
|
|
RangeMinimumQueryExprElement(Solver* solver,
|
|
const std::vector<int64_t>& values,
|
|
IntVar* index);
|
|
~RangeMinimumQueryExprElement() override {}
|
|
int64_t Min() const override;
|
|
int64_t Max() const override;
|
|
void Range(int64_t* mi, int64_t* ma) override;
|
|
void SetMin(int64_t m) override;
|
|
void SetMax(int64_t m) override;
|
|
void SetRange(int64_t mi, int64_t ma) override;
|
|
bool Bound() const override { return (index_->Bound()); }
|
|
// TODO(user) : improve me, the previous test is not always true
|
|
void WhenRange(Demon* d) override { index_->WhenRange(d); }
|
|
IntVar* CastToVar() override {
|
|
// TODO(user): Should we try to make holes in the domain of index_, as we
|
|
// do here, or should we only propagate bounds as we do in
|
|
// IncreasingIntExprElement ?
|
|
IntVar* const var = solver()->MakeIntVar(min_rmq_.array());
|
|
solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(
|
|
solver(), min_rmq_.array(), index_, var)),
|
|
var, this);
|
|
return var;
|
|
}
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
|
|
min_rmq_.array());
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
index_);
|
|
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
}
|
|
|
|
private:
|
|
int64_t IndexMin() const { return std::max<int64_t>(0, index_->Min()); }
|
|
int64_t IndexMax() const {
|
|
return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());
|
|
}
|
|
|
|
IntVar* const index_;
|
|
const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;
|
|
const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;
|
|
};
|
|
|
|
RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
|
|
Solver* solver, const std::vector<int64_t>& values, IntVar* index)
|
|
: BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {
|
|
CHECK(solver != nullptr);
|
|
CHECK(index != nullptr);
|
|
}
|
|
|
|
int64_t RangeMinimumQueryExprElement::Min() const {
|
|
return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
|
|
}
|
|
|
|
int64_t RangeMinimumQueryExprElement::Max() const {
|
|
return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
|
|
}
|
|
|
|
void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {
|
|
const int64_t range_min = IndexMin();
|
|
const int64_t range_max = IndexMax() + 1;
|
|
*mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
|
|
*ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
|
|
}
|
|
|
|
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test) \
|
|
const std::vector<int64_t>& values = min_rmq_.array(); \
|
|
int64_t index_min = IndexMin(); \
|
|
int64_t index_max = IndexMax(); \
|
|
int64_t value = values[index_min]; \
|
|
while (index_min < index_max && (test)) { \
|
|
index_min++; \
|
|
value = values[index_min]; \
|
|
} \
|
|
if (index_min == index_max && (test)) { \
|
|
solver()->Fail(); \
|
|
} \
|
|
value = values[index_max]; \
|
|
while (index_max >= index_min && (test)) { \
|
|
index_max--; \
|
|
value = values[index_max]; \
|
|
} \
|
|
index_->SetRange(index_min, index_max);
|
|
|
|
void RangeMinimumQueryExprElement::SetMin(int64_t m) {
|
|
UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < m);
|
|
}
|
|
|
|
void RangeMinimumQueryExprElement::SetMax(int64_t m) {
|
|
UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value > m);
|
|
}
|
|
|
|
void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {
|
|
if (mi > ma) {
|
|
solver()->Fail();
|
|
}
|
|
UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < mi || value > ma);
|
|
}
|
|
|
|
#undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
|
|
|
|
// ----- Increasing Element -----
|
|
|
|
class IncreasingIntExprElement : public BaseIntExpr {
|
|
public:
|
|
IncreasingIntExprElement(Solver* s, const std::vector<int64_t>& values,
|
|
IntVar* index);
|
|
~IncreasingIntExprElement() override {}
|
|
|
|
int64_t Min() const override;
|
|
void SetMin(int64_t m) override;
|
|
int64_t Max() const override;
|
|
void SetMax(int64_t m) override;
|
|
void SetRange(int64_t mi, int64_t ma) override;
|
|
bool Bound() const override { return (index_->Bound()); }
|
|
// TODO(user) : improve me, the previous test is not always true
|
|
std::string name() const override {
|
|
return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
|
|
index_->name());
|
|
}
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
|
|
index_->DebugString());
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
index_);
|
|
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
}
|
|
|
|
void WhenRange(Demon* d) override { index_->WhenRange(d); }
|
|
|
|
IntVar* CastToVar() override {
|
|
Solver* const s = solver();
|
|
IntVar* const var = s->MakeIntVar(values_);
|
|
LinkVarExpr(s, this, var);
|
|
return var;
|
|
}
|
|
|
|
private:
|
|
const std::vector<int64_t> values_;
|
|
IntVar* const index_;
|
|
};
|
|
|
|
IncreasingIntExprElement::IncreasingIntExprElement(
|
|
Solver* const s, const std::vector<int64_t>& values, IntVar* const index)
|
|
: BaseIntExpr(s), values_(values), index_(index) {
|
|
DCHECK(index);
|
|
DCHECK(s);
|
|
}
|
|
|
|
int64_t IncreasingIntExprElement::Min() const {
|
|
const int64_t expression_min = std::max<int64_t>(0, index_->Min());
|
|
return (expression_min < values_.size()
|
|
? values_[expression_min]
|
|
: std::numeric_limits<int64_t>::max());
|
|
}
|
|
|
|
void IncreasingIntExprElement::SetMin(int64_t m) {
|
|
const int64_t index_min = std::max<int64_t>(0, index_->Min());
|
|
const int64_t index_max =
|
|
std::min<int64_t>(values_.size() - 1, index_->Max());
|
|
|
|
if (index_min > index_max || m > values_[index_max]) {
|
|
solver()->Fail();
|
|
}
|
|
|
|
const std::vector<int64_t>::const_iterator first =
|
|
std::lower_bound(values_.begin(), values_.end(), m);
|
|
const int64_t new_index_min = first - values_.begin();
|
|
index_->SetMin(new_index_min);
|
|
}
|
|
|
|
int64_t IncreasingIntExprElement::Max() const {
|
|
const int64_t expression_max =
|
|
std::min<int64_t>(values_.size() - 1, index_->Max());
|
|
return (expression_max >= 0 ? values_[expression_max]
|
|
: std::numeric_limits<int64_t>::max());
|
|
}
|
|
|
|
void IncreasingIntExprElement::SetMax(int64_t m) {
|
|
int64_t index_min = std::max<int64_t>(0, index_->Min());
|
|
if (m < values_[index_min]) {
|
|
solver()->Fail();
|
|
}
|
|
|
|
const std::vector<int64_t>::const_iterator last_after =
|
|
std::upper_bound(values_.begin(), values_.end(), m);
|
|
const int64_t new_index_max = (last_after - values_.begin()) - 1;
|
|
index_->SetRange(0, new_index_max);
|
|
}
|
|
|
|
void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {
|
|
if (mi > ma) {
|
|
solver()->Fail();
|
|
}
|
|
const int64_t index_min = std::max<int64_t>(0, index_->Min());
|
|
const int64_t index_max =
|
|
std::min<int64_t>(values_.size() - 1, index_->Max());
|
|
|
|
if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
|
|
solver()->Fail();
|
|
}
|
|
|
|
const std::vector<int64_t>::const_iterator first =
|
|
std::lower_bound(values_.begin(), values_.end(), mi);
|
|
const int64_t new_index_min = first - values_.begin();
|
|
|
|
const std::vector<int64_t>::const_iterator last_after =
|
|
std::upper_bound(first, values_.end(), ma);
|
|
const int64_t new_index_max = (last_after - values_.begin()) - 1;
|
|
|
|
// Assign.
|
|
index_->SetRange(new_index_min, new_index_max);
|
|
}
|
|
|
|
// ----- Solver::MakeElement(int array, int var) -----
|
|
IntExpr* BuildElement(Solver* const solver, const std::vector<int64_t>& values,
|
|
IntVar* const index) {
|
|
// Various checks.
|
|
// Is array constant?
|
|
if (IsArrayConstant(values, values[0])) {
|
|
solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
|
|
return solver->MakeIntConst(values[0]);
|
|
}
|
|
// Is array built with booleans only?
|
|
// TODO(user): We could maintain the index of the first one.
|
|
if (IsArrayBoolean(values)) {
|
|
std::vector<int64_t> ones;
|
|
int first_zero = -1;
|
|
for (int i = 0; i < values.size(); ++i) {
|
|
if (values[i] == 1) {
|
|
ones.push_back(i);
|
|
} else {
|
|
first_zero = i;
|
|
}
|
|
}
|
|
if (ones.size() == 1) {
|
|
DCHECK_EQ(int64_t{1}, values[ones.back()]);
|
|
solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
|
|
return solver->MakeIsEqualCstVar(index, ones.back());
|
|
} else if (ones.size() == values.size() - 1) {
|
|
solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
|
|
return solver->MakeIsDifferentCstVar(index, first_zero);
|
|
} else if (ones.size() == ones.back() - ones.front() + 1) { // contiguous.
|
|
solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
|
|
IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
|
|
solver->AddConstraint(
|
|
solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
|
|
return b;
|
|
} else {
|
|
IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
|
|
solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
|
|
solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
|
|
return b;
|
|
}
|
|
}
|
|
IntExpr* cache = nullptr;
|
|
if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
|
|
cache = solver->Cache()->FindVarConstantArrayExpression(
|
|
index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
|
|
}
|
|
if (cache != nullptr) {
|
|
return cache;
|
|
} else {
|
|
IntExpr* result = nullptr;
|
|
if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
|
|
result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
|
|
values[0]);
|
|
} else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
|
|
solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
|
|
result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
|
|
values[0]);
|
|
} else if (IsIncreasingContiguous(values)) {
|
|
result = solver->MakeSum(index, values[0]);
|
|
} else if (IsIncreasing(values)) {
|
|
result = solver->RegisterIntExpr(solver->RevAlloc(
|
|
new IncreasingIntExprElement(solver, values, index)));
|
|
} else {
|
|
if (solver->parameters().use_element_rmq()) {
|
|
result = solver->RegisterIntExpr(solver->RevAlloc(
|
|
new RangeMinimumQueryExprElement(solver, values, index)));
|
|
} else {
|
|
result = solver->RegisterIntExpr(
|
|
solver->RevAlloc(new IntExprElement(solver, values, index)));
|
|
}
|
|
}
|
|
if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
|
|
solver->Cache()->InsertVarConstantArrayExpression(
|
|
result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
|
|
}
|
|
return result;
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
IntExpr* Solver::MakeElement(const std::vector<int64_t>& values,
|
|
IntVar* const index) {
|
|
DCHECK(index);
|
|
DCHECK_EQ(this, index->solver());
|
|
if (index->Bound()) {
|
|
return MakeIntConst(values[index->Min()]);
|
|
}
|
|
return BuildElement(this, values, index);
|
|
}
|
|
|
|
IntExpr* Solver::MakeElement(const std::vector<int>& values,
|
|
IntVar* const index) {
|
|
DCHECK(index);
|
|
DCHECK_EQ(this, index->solver());
|
|
if (index->Bound()) {
|
|
return MakeIntConst(values[index->Min()]);
|
|
}
|
|
return BuildElement(this, ToInt64Vector(values), index);
|
|
}
|
|
|
|
// ----- IntExprFunctionElement -----
|
|
|
|
namespace {
|
|
class IntExprFunctionElement : public BaseIntExprElement {
|
|
public:
|
|
IntExprFunctionElement(Solver* s, Solver::IndexEvaluator1 values, IntVar* e);
|
|
~IntExprFunctionElement() override;
|
|
|
|
std::string name() const override {
|
|
return absl::StrFormat("IntFunctionElement(%s)", expr_->name());
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
// Warning: This will expand all values into a vector.
|
|
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
expr_);
|
|
visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());
|
|
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
}
|
|
|
|
protected:
|
|
int64_t ElementValue(int index) const override { return values_(index); }
|
|
int64_t ExprMin() const override { return expr_->Min(); }
|
|
int64_t ExprMax() const override { return expr_->Max(); }
|
|
|
|
private:
|
|
Solver::IndexEvaluator1 values_;
|
|
};
|
|
|
|
IntExprFunctionElement::IntExprFunctionElement(Solver* const s,
|
|
Solver::IndexEvaluator1 values,
|
|
IntVar* const e)
|
|
: BaseIntExprElement(s, e), values_(std::move(values)) {
|
|
CHECK(values_ != nullptr);
|
|
}
|
|
|
|
IntExprFunctionElement::~IntExprFunctionElement() {}
|
|
|
|
// ----- Increasing Element -----
|
|
|
|
class IncreasingIntExprFunctionElement : public BaseIntExpr {
|
|
public:
|
|
IncreasingIntExprFunctionElement(Solver* const s,
|
|
Solver::IndexEvaluator1 values,
|
|
IntVar* const index)
|
|
: BaseIntExpr(s), values_(std::move(values)), index_(index) {
|
|
DCHECK(values_ != nullptr);
|
|
DCHECK(index);
|
|
DCHECK(s);
|
|
}
|
|
|
|
~IncreasingIntExprFunctionElement() override {}
|
|
|
|
int64_t Min() const override { return values_(index_->Min()); }
|
|
|
|
void SetMin(int64_t m) override {
|
|
const int64_t index_min = index_->Min();
|
|
const int64_t index_max = index_->Max();
|
|
if (m > values_(index_max)) {
|
|
solver()->Fail();
|
|
}
|
|
const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);
|
|
index_->SetMin(new_index_min);
|
|
}
|
|
|
|
int64_t Max() const override { return values_(index_->Max()); }
|
|
|
|
void SetMax(int64_t m) override {
|
|
int64_t index_min = index_->Min();
|
|
int64_t index_max = index_->Max();
|
|
if (m < values_(index_min)) {
|
|
solver()->Fail();
|
|
}
|
|
const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);
|
|
index_->SetMax(new_index_max);
|
|
}
|
|
|
|
void SetRange(int64_t mi, int64_t ma) override {
|
|
const int64_t index_min = index_->Min();
|
|
const int64_t index_max = index_->Max();
|
|
const int64_t value_min = values_(index_min);
|
|
const int64_t value_max = values_(index_max);
|
|
if (mi > ma || ma < value_min || mi > value_max) {
|
|
solver()->Fail();
|
|
}
|
|
if (mi <= value_min && ma >= value_max) {
|
|
// Nothing to do.
|
|
return;
|
|
}
|
|
|
|
const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);
|
|
const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
|
|
// Assign.
|
|
index_->SetRange(new_index_min, new_index_max);
|
|
}
|
|
|
|
std::string name() const override {
|
|
return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
|
|
index_->name());
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
|
|
index_->DebugString());
|
|
}
|
|
|
|
void WhenRange(Demon* d) override { index_->WhenRange(d); }
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
// Warning: This will expand all values into a vector.
|
|
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
index_);
|
|
if (index_->Min() == 0) {
|
|
visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,
|
|
index_->Max());
|
|
} else {
|
|
visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
|
|
index_->Max());
|
|
}
|
|
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
}
|
|
|
|
private:
|
|
int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {
|
|
if (m <= values_(index_min)) {
|
|
return index_min;
|
|
}
|
|
|
|
DCHECK_LT(values_(index_min), m);
|
|
DCHECK_GE(values_(index_max), m);
|
|
|
|
int64_t index_lower_bound = index_min;
|
|
int64_t index_upper_bound = index_max;
|
|
while (index_upper_bound - index_lower_bound > 1) {
|
|
DCHECK_LT(values_(index_lower_bound), m);
|
|
DCHECK_GE(values_(index_upper_bound), m);
|
|
const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
|
|
const int64_t pivot_value = values_(pivot);
|
|
if (pivot_value < m) {
|
|
index_lower_bound = pivot;
|
|
} else {
|
|
index_upper_bound = pivot;
|
|
}
|
|
}
|
|
DCHECK(values_(index_upper_bound) >= m);
|
|
return index_upper_bound;
|
|
}
|
|
|
|
int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {
|
|
if (m >= values_(index_max)) {
|
|
return index_max;
|
|
}
|
|
|
|
DCHECK_LE(values_(index_min), m);
|
|
DCHECK_GT(values_(index_max), m);
|
|
|
|
int64_t index_lower_bound = index_min;
|
|
int64_t index_upper_bound = index_max;
|
|
while (index_upper_bound - index_lower_bound > 1) {
|
|
DCHECK_LE(values_(index_lower_bound), m);
|
|
DCHECK_GT(values_(index_upper_bound), m);
|
|
const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
|
|
const int64_t pivot_value = values_(pivot);
|
|
if (pivot_value > m) {
|
|
index_upper_bound = pivot;
|
|
} else {
|
|
index_lower_bound = pivot;
|
|
}
|
|
}
|
|
DCHECK(values_(index_lower_bound) <= m);
|
|
return index_lower_bound;
|
|
}
|
|
|
|
Solver::IndexEvaluator1 values_;
|
|
IntVar* const index_;
|
|
};
|
|
} // namespace
|
|
|
|
IntExpr* Solver::MakeElement(Solver::IndexEvaluator1 values,
|
|
IntVar* const index) {
|
|
CHECK_EQ(this, index->solver());
|
|
return RegisterIntExpr(
|
|
RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
|
|
}
|
|
|
|
IntExpr* Solver::MakeMonotonicElement(Solver::IndexEvaluator1 values,
|
|
bool increasing, IntVar* const index) {
|
|
CHECK_EQ(this, index->solver());
|
|
if (increasing) {
|
|
return RegisterIntExpr(
|
|
RevAlloc(new IncreasingIntExprFunctionElement(this, values, index)));
|
|
} else {
|
|
// You need to pass by copy such that opposite_value does not include a
|
|
// dandling reference when leaving this scope.
|
|
Solver::IndexEvaluator1 opposite_values = [values](int64_t i) {
|
|
return -values(i);
|
|
};
|
|
return RegisterIntExpr(MakeOpposite(RevAlloc(
|
|
new IncreasingIntExprFunctionElement(this, opposite_values, index))));
|
|
}
|
|
}
|
|
|
|
// ----- IntIntExprFunctionElement -----
|
|
|
|
namespace {
|
|
class IntIntExprFunctionElement : public BaseIntExpr {
|
|
public:
|
|
IntIntExprFunctionElement(Solver* s, Solver::IndexEvaluator2 values,
|
|
IntVar* expr1, IntVar* expr2);
|
|
~IntIntExprFunctionElement() override;
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IntIntFunctionElement(%s,%s)",
|
|
expr1_->DebugString(), expr2_->DebugString());
|
|
}
|
|
int64_t Min() const override;
|
|
int64_t Max() const override;
|
|
void Range(int64_t* lower_bound, int64_t* upper_bound) override;
|
|
void SetMin(int64_t lower_bound) override;
|
|
void SetMax(int64_t upper_bound) override;
|
|
void SetRange(int64_t lower_bound, int64_t upper_bound) override;
|
|
bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
|
|
// TODO(user) : improve me, the previous test is not always true
|
|
void WhenRange(Demon* d) override {
|
|
expr1_->WhenRange(d);
|
|
expr2_->WhenRange(d);
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
expr1_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
|
|
expr2_);
|
|
// Warning: This will expand all values into a vector.
|
|
const int64_t expr1_min = expr1_->Min();
|
|
const int64_t expr1_max = expr1_->Max();
|
|
visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
|
|
visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
|
|
for (int i = expr1_min; i <= expr1_max; ++i) {
|
|
visitor->VisitInt64ToInt64Extension(
|
|
[this, i](int64_t j) { return values_(i, j); }, expr2_->Min(),
|
|
expr2_->Max());
|
|
}
|
|
visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
|
|
}
|
|
|
|
private:
|
|
int64_t ElementValue(int index1, int index2) const {
|
|
return values_(index1, index2);
|
|
}
|
|
void UpdateSupports() const;
|
|
|
|
IntVar* const expr1_;
|
|
IntVar* const expr2_;
|
|
mutable int64_t min_;
|
|
mutable int min_support1_;
|
|
mutable int min_support2_;
|
|
mutable int64_t max_;
|
|
mutable int max_support1_;
|
|
mutable int max_support2_;
|
|
mutable bool initial_update_;
|
|
Solver::IndexEvaluator2 values_;
|
|
IntVarIterator* const expr1_iterator_;
|
|
IntVarIterator* const expr2_iterator_;
|
|
};
|
|
|
|
IntIntExprFunctionElement::IntIntExprFunctionElement(
|
|
Solver* const s, Solver::IndexEvaluator2 values, IntVar* const expr1,
|
|
IntVar* const expr2)
|
|
: BaseIntExpr(s),
|
|
expr1_(expr1),
|
|
expr2_(expr2),
|
|
min_(0),
|
|
min_support1_(-1),
|
|
min_support2_(-1),
|
|
max_(0),
|
|
max_support1_(-1),
|
|
max_support2_(-1),
|
|
initial_update_(true),
|
|
values_(std::move(values)),
|
|
expr1_iterator_(expr1_->MakeDomainIterator(true)),
|
|
expr2_iterator_(expr2_->MakeDomainIterator(true)) {
|
|
CHECK(values_ != nullptr);
|
|
}
|
|
|
|
IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
|
|
|
|
int64_t IntIntExprFunctionElement::Min() const {
|
|
UpdateSupports();
|
|
return min_;
|
|
}
|
|
|
|
int64_t IntIntExprFunctionElement::Max() const {
|
|
UpdateSupports();
|
|
return max_;
|
|
}
|
|
|
|
void IntIntExprFunctionElement::Range(int64_t* lower_bound,
|
|
int64_t* upper_bound) {
|
|
UpdateSupports();
|
|
*lower_bound = min_;
|
|
*upper_bound = max_;
|
|
}
|
|
|
|
#define UPDATE_ELEMENT_INDEX_BOUNDS(test) \
|
|
const int64_t emin1 = expr1_->Min(); \
|
|
const int64_t emax1 = expr1_->Max(); \
|
|
const int64_t emin2 = expr2_->Min(); \
|
|
const int64_t emax2 = expr2_->Max(); \
|
|
int64_t nmin1 = emin1; \
|
|
bool found = false; \
|
|
while (nmin1 <= emax1 && !found) { \
|
|
for (int i = emin2; i <= emax2; ++i) { \
|
|
int64_t value = ElementValue(nmin1, i); \
|
|
if (test) { \
|
|
found = true; \
|
|
break; \
|
|
} \
|
|
} \
|
|
if (!found) { \
|
|
nmin1++; \
|
|
} \
|
|
} \
|
|
if (nmin1 > emax1) { \
|
|
solver()->Fail(); \
|
|
} \
|
|
int64_t nmin2 = emin2; \
|
|
found = false; \
|
|
while (nmin2 <= emax2 && !found) { \
|
|
for (int i = emin1; i <= emax1; ++i) { \
|
|
int64_t value = ElementValue(i, nmin2); \
|
|
if (test) { \
|
|
found = true; \
|
|
break; \
|
|
} \
|
|
} \
|
|
if (!found) { \
|
|
nmin2++; \
|
|
} \
|
|
} \
|
|
if (nmin2 > emax2) { \
|
|
solver()->Fail(); \
|
|
} \
|
|
int64_t nmax1 = emax1; \
|
|
found = false; \
|
|
while (nmax1 >= nmin1 && !found) { \
|
|
for (int i = emin2; i <= emax2; ++i) { \
|
|
int64_t value = ElementValue(nmax1, i); \
|
|
if (test) { \
|
|
found = true; \
|
|
break; \
|
|
} \
|
|
} \
|
|
if (!found) { \
|
|
nmax1--; \
|
|
} \
|
|
} \
|
|
int64_t nmax2 = emax2; \
|
|
found = false; \
|
|
while (nmax2 >= nmin2 && !found) { \
|
|
for (int i = emin1; i <= emax1; ++i) { \
|
|
int64_t value = ElementValue(i, nmax2); \
|
|
if (test) { \
|
|
found = true; \
|
|
break; \
|
|
} \
|
|
} \
|
|
if (!found) { \
|
|
nmax2--; \
|
|
} \
|
|
} \
|
|
expr1_->SetRange(nmin1, nmax1); \
|
|
expr2_->SetRange(nmin2, nmax2);
|
|
|
|
void IntIntExprFunctionElement::SetMin(int64_t lower_bound) {
|
|
UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound);
|
|
}
|
|
|
|
void IntIntExprFunctionElement::SetMax(int64_t upper_bound) {
|
|
UPDATE_ELEMENT_INDEX_BOUNDS(value <= upper_bound);
|
|
}
|
|
|
|
void IntIntExprFunctionElement::SetRange(int64_t lower_bound,
|
|
int64_t upper_bound) {
|
|
if (lower_bound > upper_bound) {
|
|
solver()->Fail();
|
|
}
|
|
UPDATE_ELEMENT_INDEX_BOUNDS(value >= lower_bound && value <= upper_bound);
|
|
}
|
|
|
|
#undef UPDATE_ELEMENT_INDEX_BOUNDS
|
|
|
|
void IntIntExprFunctionElement::UpdateSupports() const {
|
|
if (initial_update_ || !expr1_->Contains(min_support1_) ||
|
|
!expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||
|
|
!expr2_->Contains(max_support2_)) {
|
|
const int64_t emax1 = expr1_->Max();
|
|
const int64_t emax2 = expr2_->Max();
|
|
int64_t min_value = ElementValue(emax1, emax2);
|
|
int64_t max_value = min_value;
|
|
int min_support1 = emax1;
|
|
int max_support1 = emax1;
|
|
int min_support2 = emax2;
|
|
int max_support2 = emax2;
|
|
for (const int64_t index1 : InitAndGetValues(expr1_iterator_)) {
|
|
for (const int64_t index2 : InitAndGetValues(expr2_iterator_)) {
|
|
const int64_t value = ElementValue(index1, index2);
|
|
if (value > max_value) {
|
|
max_value = value;
|
|
max_support1 = index1;
|
|
max_support2 = index2;
|
|
} else if (value < min_value) {
|
|
min_value = value;
|
|
min_support1 = index1;
|
|
min_support2 = index2;
|
|
}
|
|
}
|
|
}
|
|
Solver* s = solver();
|
|
s->SaveAndSetValue(&min_, min_value);
|
|
s->SaveAndSetValue(&min_support1_, min_support1);
|
|
s->SaveAndSetValue(&min_support2_, min_support2);
|
|
s->SaveAndSetValue(&max_, max_value);
|
|
s->SaveAndSetValue(&max_support1_, max_support1);
|
|
s->SaveAndSetValue(&max_support2_, max_support2);
|
|
s->SaveAndSetValue(&initial_update_, false);
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
IntExpr* Solver::MakeElement(Solver::IndexEvaluator2 values,
|
|
IntVar* const index1, IntVar* const index2) {
|
|
CHECK_EQ(this, index1->solver());
|
|
CHECK_EQ(this, index2->solver());
|
|
return RegisterIntExpr(RevAlloc(
|
|
new IntIntExprFunctionElement(this, std::move(values), index1, index2)));
|
|
}
|
|
|
|
// ---------- Generalized element ----------
|
|
|
|
// ----- IfThenElseCt -----
|
|
|
|
class IfThenElseCt : public CastConstraint {
|
|
public:
|
|
IfThenElseCt(Solver* const solver, IntVar* const condition,
|
|
IntExpr* const one, IntExpr* const zero, IntVar* const target)
|
|
: CastConstraint(solver, target),
|
|
condition_(condition),
|
|
zero_(zero),
|
|
one_(one) {}
|
|
|
|
~IfThenElseCt() override {}
|
|
|
|
void Post() override {
|
|
Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
|
|
condition_->WhenBound(demon);
|
|
one_->WhenRange(demon);
|
|
zero_->WhenRange(demon);
|
|
target_var_->WhenRange(demon);
|
|
}
|
|
|
|
void InitialPropagate() override {
|
|
condition_->SetRange(0, 1);
|
|
const int64_t target_var_min = target_var_->Min();
|
|
const int64_t target_var_max = target_var_->Max();
|
|
int64_t new_min = std::numeric_limits<int64_t>::min();
|
|
int64_t new_max = std::numeric_limits<int64_t>::max();
|
|
if (condition_->Max() == 0) {
|
|
zero_->SetRange(target_var_min, target_var_max);
|
|
zero_->Range(&new_min, &new_max);
|
|
} else if (condition_->Min() == 1) {
|
|
one_->SetRange(target_var_min, target_var_max);
|
|
one_->Range(&new_min, &new_max);
|
|
} else {
|
|
if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {
|
|
condition_->SetValue(1);
|
|
one_->SetRange(target_var_min, target_var_max);
|
|
one_->Range(&new_min, &new_max);
|
|
} else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {
|
|
condition_->SetValue(0);
|
|
zero_->SetRange(target_var_min, target_var_max);
|
|
zero_->Range(&new_min, &new_max);
|
|
} else {
|
|
int64_t zl = 0;
|
|
int64_t zu = 0;
|
|
int64_t ol = 0;
|
|
int64_t ou = 0;
|
|
zero_->Range(&zl, &zu);
|
|
one_->Range(&ol, &ou);
|
|
new_min = std::min(zl, ol);
|
|
new_max = std::max(zu, ou);
|
|
}
|
|
}
|
|
target_var_->SetRange(new_min, new_max);
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),
|
|
one_->DebugString(), zero_->DebugString(),
|
|
target_var_->DebugString());
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {}
|
|
|
|
private:
|
|
IntVar* const condition_;
|
|
IntExpr* const zero_;
|
|
IntExpr* const one_;
|
|
};
|
|
|
|
// ----- IntExprEvaluatorElementCt -----
|
|
|
|
// This constraint implements evaluator(index) == var. It is delayed such
|
|
// that propagation only occurs when all variables have been touched.
|
|
// The range of the evaluator is [range_start, range_end).
|
|
|
|
namespace {
|
|
class IntExprEvaluatorElementCt : public CastConstraint {
|
|
public:
|
|
IntExprEvaluatorElementCt(Solver* s, Solver::Int64ToIntVar evaluator,
|
|
int64_t range_start, int64_t range_end,
|
|
IntVar* index, IntVar* target_var);
|
|
~IntExprEvaluatorElementCt() override {}
|
|
|
|
void Post() override;
|
|
void InitialPropagate() override;
|
|
|
|
void Propagate();
|
|
void Update(int index);
|
|
void UpdateExpr();
|
|
|
|
std::string DebugString() const override;
|
|
void Accept(ModelVisitor* visitor) const override;
|
|
|
|
protected:
|
|
IntVar* const index_;
|
|
|
|
private:
|
|
const Solver::Int64ToIntVar evaluator_;
|
|
const int64_t range_start_;
|
|
const int64_t range_end_;
|
|
int min_support_;
|
|
int max_support_;
|
|
};
|
|
|
|
IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
|
|
Solver* const s, Solver::Int64ToIntVar evaluator, int64_t range_start,
|
|
int64_t range_end, IntVar* const index, IntVar* const target_var)
|
|
: CastConstraint(s, target_var),
|
|
index_(index),
|
|
evaluator_(std::move(evaluator)),
|
|
range_start_(range_start),
|
|
range_end_(range_end),
|
|
min_support_(-1),
|
|
max_support_(-1) {}
|
|
|
|
void IntExprEvaluatorElementCt::Post() {
|
|
Demon* const delayed_propagate_demon = MakeDelayedConstraintDemon0(
|
|
solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");
|
|
for (int i = range_start_; i < range_end_; ++i) {
|
|
IntVar* const current_var = evaluator_(i);
|
|
current_var->WhenRange(delayed_propagate_demon);
|
|
Demon* const update_demon = MakeConstraintDemon1(
|
|
solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);
|
|
current_var->WhenRange(update_demon);
|
|
}
|
|
index_->WhenRange(delayed_propagate_demon);
|
|
Demon* const update_expr_demon = MakeConstraintDemon0(
|
|
solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");
|
|
index_->WhenRange(update_expr_demon);
|
|
Demon* const update_var_demon = MakeConstraintDemon0(
|
|
solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");
|
|
|
|
target_var_->WhenRange(update_var_demon);
|
|
}
|
|
|
|
void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
|
|
|
|
void IntExprEvaluatorElementCt::Propagate() {
|
|
const int64_t emin = std::max(range_start_, index_->Min());
|
|
const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->Max());
|
|
const int64_t vmin = target_var_->Min();
|
|
const int64_t vmax = target_var_->Max();
|
|
if (emin == emax) {
|
|
index_->SetValue(emin); // in case it was reduced by the above min/max.
|
|
evaluator_(emin)->SetRange(vmin, vmax);
|
|
} else {
|
|
int64_t nmin = emin;
|
|
for (; nmin <= emax; nmin++) {
|
|
// break if the intersection of
|
|
// [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
|
|
// is non-empty.
|
|
IntVar* const nmin_var = evaluator_(nmin);
|
|
if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;
|
|
}
|
|
int64_t nmax = emax;
|
|
for (; nmin <= nmax; nmax--) {
|
|
// break if the intersection of
|
|
// [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
|
|
// is non-empty.
|
|
IntExpr* const nmax_var = evaluator_(nmax);
|
|
if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;
|
|
}
|
|
index_->SetRange(nmin, nmax);
|
|
if (nmin == nmax) {
|
|
evaluator_(nmin)->SetRange(vmin, vmax);
|
|
}
|
|
}
|
|
if (min_support_ == -1 || max_support_ == -1) {
|
|
int min_support = -1;
|
|
int max_support = -1;
|
|
int64_t gmin = std::numeric_limits<int64_t>::max();
|
|
int64_t gmax = std::numeric_limits<int64_t>::min();
|
|
for (int i = index_->Min(); i <= index_->Max(); ++i) {
|
|
IntExpr* const var_i = evaluator_(i);
|
|
const int64_t vmin = var_i->Min();
|
|
if (vmin < gmin) {
|
|
gmin = vmin;
|
|
}
|
|
const int64_t vmax = var_i->Max();
|
|
if (vmax > gmax) {
|
|
gmax = vmax;
|
|
}
|
|
}
|
|
solver()->SaveAndSetValue(&min_support_, min_support);
|
|
solver()->SaveAndSetValue(&max_support_, max_support);
|
|
target_var_->SetRange(gmin, gmax);
|
|
}
|
|
}
|
|
|
|
void IntExprEvaluatorElementCt::Update(int index) {
|
|
if (index == min_support_ || index == max_support_) {
|
|
solver()->SaveAndSetValue(&min_support_, -1);
|
|
solver()->SaveAndSetValue(&max_support_, -1);
|
|
}
|
|
}
|
|
|
|
void IntExprEvaluatorElementCt::UpdateExpr() {
|
|
if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {
|
|
solver()->SaveAndSetValue(&min_support_, -1);
|
|
solver()->SaveAndSetValue(&max_support_, -1);
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
std::string StringifyEvaluatorBare(const Solver::Int64ToIntVar& evaluator,
|
|
int64_t range_start, int64_t range_end) {
|
|
std::string out;
|
|
for (int64_t i = range_start; i < range_end; ++i) {
|
|
if (i != range_start) {
|
|
out += ", ";
|
|
}
|
|
out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());
|
|
}
|
|
return out;
|
|
}
|
|
|
|
std::string StringifyInt64ToIntVar(const Solver::Int64ToIntVar& evaluator,
|
|
int64_t range_begin, int64_t range_end) {
|
|
std::string out;
|
|
if (range_end - range_begin > 10) {
|
|
out = absl::StrFormat(
|
|
"IntToIntVar(%s, ...%s)",
|
|
StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
|
|
StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
|
|
} else {
|
|
out = absl::StrFormat(
|
|
"IntToIntVar(%s)",
|
|
StringifyEvaluatorBare(evaluator, range_begin, range_end));
|
|
}
|
|
return out;
|
|
}
|
|
} // namespace
|
|
|
|
std::string IntExprEvaluatorElementCt::DebugString() const {
|
|
return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);
|
|
}
|
|
|
|
void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
|
|
visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
visitor->VisitIntegerVariableEvaluatorArgument(
|
|
ModelVisitor::kEvaluatorArgument, evaluator_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
|
|
target_var_);
|
|
visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
}
|
|
|
|
// ----- IntExprArrayElementCt -----
|
|
|
|
// This constraint implements vars[index] == var. It is delayed such
|
|
// that propagation only occurs when all variables have been touched.
|
|
|
|
class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
|
|
public:
|
|
IntExprArrayElementCt(Solver* s, std::vector<IntVar*> vars, IntVar* index,
|
|
IntVar* target_var);
|
|
|
|
std::string DebugString() const override;
|
|
void Accept(ModelVisitor* visitor) const override;
|
|
|
|
private:
|
|
const std::vector<IntVar*> vars_;
|
|
};
|
|
|
|
IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
|
|
std::vector<IntVar*> vars,
|
|
IntVar* const index,
|
|
IntVar* const target_var)
|
|
: IntExprEvaluatorElementCt(
|
|
s, [this](int64_t idx) { return vars_[idx]; }, 0, vars.size(), index,
|
|
target_var),
|
|
vars_(std::move(vars)) {}
|
|
|
|
std::string IntExprArrayElementCt::DebugString() const {
|
|
int64_t size = vars_.size();
|
|
if (size > 10) {
|
|
return absl::StrFormat(
|
|
"IntExprArrayElement(var array of size %d, %s) == %s", size,
|
|
index_->DebugString(), target_var_->DebugString());
|
|
} else {
|
|
return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",
|
|
JoinDebugStringPtr(vars_, ", "),
|
|
index_->DebugString(), target_var_->DebugString());
|
|
}
|
|
}
|
|
|
|
void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
|
|
visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
|
|
vars_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
|
|
target_var_);
|
|
visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
}
|
|
|
|
// ----- IntExprArrayElementCstCt -----
|
|
|
|
// This constraint implements vars[index] == constant.
|
|
|
|
class IntExprArrayElementCstCt : public Constraint {
|
|
public:
|
|
IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,
|
|
IntVar* const index, int64_t target)
|
|
: Constraint(s),
|
|
vars_(vars),
|
|
index_(index),
|
|
target_(target),
|
|
demons_(vars.size()) {}
|
|
|
|
~IntExprArrayElementCstCt() override {}
|
|
|
|
void Post() override {
|
|
for (int i = 0; i < vars_.size(); ++i) {
|
|
demons_[i] = MakeConstraintDemon1(
|
|
solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);
|
|
vars_[i]->WhenDomain(demons_[i]);
|
|
}
|
|
Demon* const index_demon = MakeConstraintDemon0(
|
|
solver(), this, &IntExprArrayElementCstCt::PropagateIndex,
|
|
"PropagateIndex");
|
|
index_->WhenBound(index_demon);
|
|
}
|
|
|
|
void InitialPropagate() override {
|
|
for (int i = 0; i < vars_.size(); ++i) {
|
|
Propagate(i);
|
|
}
|
|
PropagateIndex();
|
|
}
|
|
|
|
void Propagate(int index) {
|
|
if (!vars_[index]->Contains(target_)) {
|
|
index_->RemoveValue(index);
|
|
demons_[index]->inhibit(solver());
|
|
}
|
|
}
|
|
|
|
void PropagateIndex() {
|
|
if (index_->Bound()) {
|
|
vars_[index_->Min()]->SetValue(target_);
|
|
}
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",
|
|
JoinDebugStringPtr(vars_, ", "),
|
|
index_->DebugString(), target_);
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
|
|
vars_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
index_);
|
|
visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
|
|
visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
|
|
}
|
|
|
|
private:
|
|
const std::vector<IntVar*> vars_;
|
|
IntVar* const index_;
|
|
const int64_t target_;
|
|
std::vector<Demon*> demons_;
|
|
};
|
|
|
|
// This constraint implements index == position(constant in vars).
|
|
|
|
class IntExprIndexOfCt : public Constraint {
|
|
public:
|
|
IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,
|
|
IntVar* const index, int64_t target)
|
|
: Constraint(s),
|
|
vars_(vars),
|
|
index_(index),
|
|
target_(target),
|
|
demons_(vars_.size()),
|
|
index_iterator_(index->MakeHoleIterator(true)) {}
|
|
|
|
~IntExprIndexOfCt() override {}
|
|
|
|
void Post() override {
|
|
for (int i = 0; i < vars_.size(); ++i) {
|
|
demons_[i] = MakeConstraintDemon1(
|
|
solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);
|
|
vars_[i]->WhenDomain(demons_[i]);
|
|
}
|
|
Demon* const index_demon = MakeConstraintDemon0(
|
|
solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");
|
|
index_->WhenDomain(index_demon);
|
|
}
|
|
|
|
void InitialPropagate() override {
|
|
for (int i = 0; i < vars_.size(); ++i) {
|
|
if (!index_->Contains(i)) {
|
|
vars_[i]->RemoveValue(target_);
|
|
} else if (!vars_[i]->Contains(target_)) {
|
|
index_->RemoveValue(i);
|
|
demons_[i]->inhibit(solver());
|
|
} else if (vars_[i]->Bound()) {
|
|
index_->SetValue(i);
|
|
demons_[i]->inhibit(solver());
|
|
}
|
|
}
|
|
}
|
|
|
|
void Propagate(int index) {
|
|
if (!vars_[index]->Contains(target_)) {
|
|
index_->RemoveValue(index);
|
|
demons_[index]->inhibit(solver());
|
|
} else if (vars_[index]->Bound()) {
|
|
index_->SetValue(index);
|
|
}
|
|
}
|
|
|
|
void PropagateIndex() {
|
|
const int64_t oldmax = index_->OldMax();
|
|
const int64_t vmin = index_->Min();
|
|
const int64_t vmax = index_->Max();
|
|
for (int64_t value = index_->OldMin(); value < vmin; ++value) {
|
|
vars_[value]->RemoveValue(target_);
|
|
demons_[value]->inhibit(solver());
|
|
}
|
|
for (const int64_t value : InitAndGetValues(index_iterator_)) {
|
|
vars_[value]->RemoveValue(target_);
|
|
demons_[value]->inhibit(solver());
|
|
}
|
|
for (int64_t value = vmax + 1; value <= oldmax; ++value) {
|
|
vars_[value]->RemoveValue(target_);
|
|
demons_[value]->inhibit(solver());
|
|
}
|
|
if (index_->Bound()) {
|
|
vars_[index_->Min()]->SetValue(target_);
|
|
}
|
|
}
|
|
|
|
std::string DebugString() const override {
|
|
return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",
|
|
JoinDebugStringPtr(vars_, ", "),
|
|
index_->DebugString(), target_);
|
|
}
|
|
|
|
void Accept(ModelVisitor* const visitor) const override {
|
|
visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
|
|
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
|
|
vars_);
|
|
visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
|
|
index_);
|
|
visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
|
|
visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
|
|
}
|
|
|
|
private:
|
|
const std::vector<IntVar*> vars_;
|
|
IntVar* const index_;
|
|
const int64_t target_;
|
|
std::vector<Demon*> demons_;
|
|
IntVarIterator* const index_iterator_;
|
|
};
|
|
|
|
// Factory helper.
|
|
|
|
Constraint* MakeElementEqualityFunc(Solver* const solver,
|
|
const std::vector<int64_t>& vals,
|
|
IntVar* const index, IntVar* const target) {
|
|
if (index->Bound()) {
|
|
const int64_t val = index->Min();
|
|
if (val < 0 || val >= vals.size()) {
|
|
return solver->MakeFalseConstraint();
|
|
} else {
|
|
return solver->MakeEquality(target, vals[val]);
|
|
}
|
|
} else {
|
|
if (IsIncreasingContiguous(vals)) {
|
|
return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));
|
|
} else {
|
|
return solver->RevAlloc(
|
|
new IntElementConstraint(solver, vals, index, target));
|
|
}
|
|
}
|
|
}
|
|
} // namespace
|
|
|
|
Constraint* Solver::MakeIfThenElseCt(IntVar* const condition,
|
|
IntExpr* const then_expr,
|
|
IntExpr* const else_expr,
|
|
IntVar* const target_var) {
|
|
return RevAlloc(
|
|
new IfThenElseCt(this, condition, then_expr, else_expr, target_var));
|
|
}
|
|
|
|
IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars,
|
|
IntVar* const index) {
|
|
if (index->Bound()) {
|
|
return vars[index->Min()];
|
|
}
|
|
const int size = vars.size();
|
|
if (AreAllBound(vars)) {
|
|
std::vector<int64_t> values(size);
|
|
for (int i = 0; i < size; ++i) {
|
|
values[i] = vars[i]->Value();
|
|
}
|
|
return MakeElement(values, index);
|
|
}
|
|
if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
|
|
index->Min() >= 0 && index->Max() < vars.size()) {
|
|
// Let's get the index between 0 and 1.
|
|
IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
|
|
IntVar* const zero = vars[index->Min()];
|
|
IntVar* const one = vars[index->Max()];
|
|
const std::string name = absl::StrFormat(
|
|
"ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());
|
|
IntVar* const target = MakeIntVar(std::min(zero->Min(), one->Min()),
|
|
std::max(zero->Max(), one->Max()), name);
|
|
AddConstraint(
|
|
RevAlloc(new IfThenElseCt(this, scaled_index, one, zero, target)));
|
|
return target;
|
|
}
|
|
int64_t emin = std::numeric_limits<int64_t>::max();
|
|
int64_t emax = std::numeric_limits<int64_t>::min();
|
|
std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
|
|
for (const int64_t index_value : InitAndGetValues(iterator.get())) {
|
|
if (index_value >= 0 && index_value < size) {
|
|
emin = std::min(emin, vars[index_value]->Min());
|
|
emax = std::max(emax, vars[index_value]->Max());
|
|
}
|
|
}
|
|
const std::string vname =
|
|
size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,
|
|
index->DebugString())
|
|
: absl::StrFormat("ElementVar([%s], %s)",
|
|
JoinNamePtr(vars, ", "), index->name());
|
|
IntVar* const element_var = MakeIntVar(emin, emax, vname);
|
|
AddConstraint(
|
|
RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));
|
|
return element_var;
|
|
}
|
|
|
|
IntExpr* Solver::MakeElement(Int64ToIntVar vars, int64_t range_start,
|
|
int64_t range_end, IntVar* argument) {
|
|
const std::string index_name =
|
|
!argument->name().empty() ? argument->name() : argument->DebugString();
|
|
const std::string vname = absl::StrFormat(
|
|
"ElementVar(%s, %s)",
|
|
StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
|
|
IntVar* const element_var =
|
|
MakeIntVar(std::numeric_limits<int64_t>::min(),
|
|
std::numeric_limits<int64_t>::max(), vname);
|
|
IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(
|
|
this, std::move(vars), range_start, range_end, argument, element_var);
|
|
AddConstraint(RevAlloc(evaluation_ct));
|
|
evaluation_ct->Propagate();
|
|
return element_var;
|
|
}
|
|
|
|
Constraint* Solver::MakeElementEquality(const std::vector<int64_t>& vals,
|
|
IntVar* const index,
|
|
IntVar* const target) {
|
|
return MakeElementEqualityFunc(this, vals, index, target);
|
|
}
|
|
|
|
Constraint* Solver::MakeElementEquality(const std::vector<int>& vals,
|
|
IntVar* const index,
|
|
IntVar* const target) {
|
|
return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);
|
|
}
|
|
|
|
Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
|
|
IntVar* const index,
|
|
IntVar* const target) {
|
|
if (AreAllBound(vars)) {
|
|
std::vector<int64_t> values(vars.size());
|
|
for (int i = 0; i < vars.size(); ++i) {
|
|
values[i] = vars[i]->Value();
|
|
}
|
|
return MakeElementEquality(values, index, target);
|
|
}
|
|
if (index->Bound()) {
|
|
const int64_t val = index->Min();
|
|
if (val < 0 || val >= vars.size()) {
|
|
return MakeFalseConstraint();
|
|
} else {
|
|
return MakeEquality(target, vars[val]);
|
|
}
|
|
} else {
|
|
if (target->Bound()) {
|
|
return RevAlloc(
|
|
new IntExprArrayElementCstCt(this, vars, index, target->Min()));
|
|
} else {
|
|
return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));
|
|
}
|
|
}
|
|
}
|
|
|
|
Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
|
|
IntVar* const index, int64_t target) {
|
|
if (AreAllBound(vars)) {
|
|
std::vector<int> valid_indices;
|
|
for (int i = 0; i < vars.size(); ++i) {
|
|
if (vars[i]->Value() == target) {
|
|
valid_indices.push_back(i);
|
|
}
|
|
}
|
|
return MakeMemberCt(index, valid_indices);
|
|
}
|
|
if (index->Bound()) {
|
|
const int64_t pos = index->Min();
|
|
if (pos >= 0 && pos < vars.size()) {
|
|
IntVar* const var = vars[pos];
|
|
return MakeEquality(var, target);
|
|
} else {
|
|
return MakeFalseConstraint();
|
|
}
|
|
} else {
|
|
return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
|
|
}
|
|
}
|
|
|
|
Constraint* Solver::MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
|
|
IntVar* const index, int64_t target) {
|
|
if (index->Bound()) {
|
|
const int64_t pos = index->Min();
|
|
if (pos >= 0 && pos < vars.size()) {
|
|
IntVar* const var = vars[pos];
|
|
return MakeEquality(var, target);
|
|
} else {
|
|
return MakeFalseConstraint();
|
|
}
|
|
} else {
|
|
return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
|
|
}
|
|
}
|
|
|
|
IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
|
|
int64_t value) {
|
|
IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(
|
|
vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
|
|
if (cache != nullptr) {
|
|
return cache->Var();
|
|
} else {
|
|
const std::string name =
|
|
absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);
|
|
IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
|
|
AddConstraint(MakeIndexOfConstraint(vars, index, value));
|
|
model_cache_->InsertVarArrayConstantExpression(
|
|
index, vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
|
|
return index;
|
|
}
|
|
}
|
|
} // namespace operations_research
|