cache binary and n-ary sum

This commit is contained in:
lperron@google.com
2012-06-14 19:32:41 +00:00
parent f5546aa29e
commit 9b359f5c2c
4 changed files with 128 additions and 39 deletions

View File

@@ -274,15 +274,15 @@ inline uint64 Hash1(ConstIntArray* const values) {
}
}
template <class T> uint64 Hash1(std::vector<T*>* const ptrs) {
if (ptrs->size() == 0) {
template <class T> uint64 Hash1(const std::vector<T*>& ptrs) {
if (ptrs.size() == 0) {
return 0;
} else if (ptrs->size() == 1) {
return Hash1((*ptrs)[0]);
} else if (ptrs.size() == 1) {
return Hash1(ptrs[0]);
} else {
uint64 hash = Hash1((*ptrs)[0]);
for (int i = 1; i < ptrs->size(); ++i) {
hash = hash * i + Hash1((*ptrs)[i]);
uint64 hash = Hash1(ptrs[0]);
for (int i = 1; i < ptrs.size(); ++i) {
hash = hash * i + Hash1(ptrs[i]);
}
return hash;
}
@@ -1442,18 +1442,22 @@ class ModelCache {
};
enum VarVarExpressionType {
VAR_VAR_DIFFERENCE = 0,
VAR_VAR_PROD,
VAR_VAR_MAX,
VAR_VAR_MIN,
VAR_VAR_SUM,
VAR_VAR_IS_EQUAL,
VAR_VAR_IS_EQUAL = 0,
VAR_VAR_IS_NOT_EQUAL,
VAR_VAR_IS_LESS,
VAR_VAR_IS_LESS_OR_EQUAL,
VAR_VAR_EXPRESSION_MAX,
};
enum ExprExprExpressionType {
EXPR_EXPR_DIFFERENCE = 0,
EXPR_EXPR_PROD,
EXPR_EXPR_MAX,
EXPR_EXPR_MIN,
EXPR_EXPR_SUM,
EXPR_EXPR_EXPRESSION_MAX,
};
enum VarConstantConstantExpressionType {
VAR_CONSTANT_CONSTANT_SEMI_CONTINUOUS = 0,
VAR_CONSTANT_CONSTANT_EXPRESSION_MAX,
@@ -1556,6 +1560,19 @@ class ModelCache {
IntVar* const var2,
VarVarExpressionType type) = 0;
// Expr Expr Expressions.
virtual IntExpr* FindExprExprExpression(
IntExpr* const var1,
IntExpr* const var2,
ExprExprExpressionType type) const = 0;
virtual void InsertExprExprExpression(
IntExpr* const expression,
IntExpr* const var1,
IntExpr* const var2,
ExprExprExpressionType type) = 0;
// Var Constant Constant Expressions.
virtual IntExpr* FindVarConstantConstantExpression(
@@ -1587,12 +1604,12 @@ class ModelCache {
// Var Array Expressions.
virtual IntExpr* FindVarArrayExpression(
std::vector<IntVar*>* const vars,
const std::vector<IntVar*>& vars,
VarArrayExpressionType type) const = 0;
virtual void InsertVarArrayExpression(
IntExpr* const expression,
std::vector<IntVar*>* const vars,
const std::vector<IntVar*>& vars,
VarArrayExpressionType type) = 0;
Solver* solver() const;

View File

@@ -848,13 +848,21 @@ IntExpr* Solver::MakeSum(const std::vector<IntVar*>& vars) {
new_max = CapAdd(vars[i]->Max(), new_max);
}
}
IntVar* const sum_var = MakeIntVar(new_min, new_max);
if (new_min != kint64min && new_max != kint64max) {
AddConstraint(RevAlloc(new SumConstraint(this, vars, sum_var)));
IntExpr* const cache =
model_cache_->FindVarArrayExpression(vars, ModelCache::VAR_ARRAY_SUM);
if (cache != NULL) {
return cache->Var();
} else {
AddConstraint(RevAlloc(new SafeSumConstraint(this, vars, sum_var)));
IntVar* const sum_var = MakeIntVar(new_min, new_max);
if (new_min != kint64min && new_max != kint64max) {
AddConstraint(RevAlloc(new SumConstraint(this, vars, sum_var)));
} else {
AddConstraint(RevAlloc(new SafeSumConstraint(this, vars, sum_var)));
}
model_cache_->InsertVarArrayExpression(
sum_var, vars, ModelCache::VAR_ARRAY_SUM);
return sum_var;
}
return sum_var;
}
}
@@ -941,6 +949,15 @@ template<class T> bool AreAllNull(const T* const values, int size) {
return true;
}
template<class T> bool AreAllOnes(const T* const values, int size) {
for (int i = 0; i < size; ++i) {
if (values[i] != 1) {
return false;
}
}
return true;
}
template <class T> bool AreAllBoundOrNull(const IntVar* const * vars,
const T* const values,
int size) {
@@ -2357,23 +2374,27 @@ Constraint* Solver::MakeScalProdLessOrEqual(const std::vector<IntVar*>& vars,
namespace {
template<class T> IntExpr* MakeScalProdFct(Solver* solver,
IntVar* const * vars,
const T* const coefs,
int size) {
if (size == 0 || AreAllNull<T>(coefs, size)) {
const std::vector<IntVar*>& vars,
const std::vector<T>& coefs) {
const int size = vars.size();
if (vars.empty() || AreAllNull<T>(coefs.data(), size)) {
return solver->MakeIntConst(0LL);
}
if (AreAllBoundOrNull(vars, coefs, size)) {
if (AreAllBoundOrNull(vars.data(), coefs.data(), size)) {
int64 cst = 0;
for (int i = 0; i < size; ++i) {
cst += vars[i]->Min() * coefs[i];
}
return solver->MakeIntConst(cst);
}
if (AreAllBooleans(vars, size)) {
if (AreAllPositive<T>(coefs, size)) {
if (AreAllOnes(coefs.data(), size)) {
return solver->MakeSum(vars);
}
if (AreAllBooleans(vars.data(), size)) {
if (AreAllPositive<T>(coefs.data(), size)) {
return solver->RegisterIntExpr(solver->RevAlloc(
new PositiveBooleanScalProd(solver, vars, size, coefs)));
new PositiveBooleanScalProd(
solver, vars.data(), size, coefs.data())));
} else {
// If some coefficients are non-positive, partition coefficients in two
// sets, one for the positive coefficients P and one for the negative
@@ -2429,13 +2450,13 @@ template<class T> IntExpr* MakeScalProdFct(Solver* solver,
IntExpr* Solver::MakeScalProd(const std::vector<IntVar*>& vars,
const std::vector<int64>& coefs) {
DCHECK_EQ(vars.size(), coefs.size());
return MakeScalProdFct<int64>(this, vars.data(), coefs.data(), vars.size());
return MakeScalProdFct<int64>(this, vars, coefs);
}
IntExpr* Solver::MakeScalProd(const std::vector<IntVar*>& vars,
const std::vector<int>& coefs) {
DCHECK_EQ(vars.size(), coefs.size());
return MakeScalProdFct<int>(this, vars.data(), coefs.data(), vars.size());
return MakeScalProdFct<int>(this, vars, coefs);
}

View File

@@ -4618,7 +4618,24 @@ IntExpr* Solver::MakeSum(IntExpr* const l, IntExpr* const r) {
if (l == r) {
return MakeProd(l, 2);
}
return RegisterIntExpr(RevAlloc(new PlusIntExpr(this, l, r)));
IntExpr* cache =
model_cache_->FindExprExprExpression(l, r, ModelCache::EXPR_EXPR_SUM);
if (cache == NULL) {
cache =
model_cache_->FindExprExprExpression(r, l, ModelCache::EXPR_EXPR_SUM);
}
if (cache != NULL) {
return cache;
} else {
IntExpr* const result =
RegisterIntExpr(RevAlloc(new PlusIntExpr(this, l, r)));
model_cache_->InsertExprExprExpression(
result,
l,
r,
ModelCache::EXPR_EXPR_SUM);
return result;
}
}
IntExpr* Solver::MakeSum(IntExpr* const e, int64 v) {

View File

@@ -48,13 +48,13 @@ bool IsEqual(const ConstIntArray*& a1, const ConstIntArray*& a2) {
return a1->Equals(*a2);
}
template<class T> bool IsEqual(std::vector<T*>* const a1,
std::vector<T*>* const a2) {
if (a1->size() != a2->size()) {
template<class T> bool IsEqual(const std::vector<T*>& a1,
const std::vector<T*>& a2) {
if (a1.size() != a2.size()) {
return false;
}
for (int i = 0; i < a1->size(); ++i) {
if ((*a1)[i] != (*a2)[i]) {
for (int i = 0; i < a1.size(); ++i) {
if (a1[i] != a2[i]) {
return false;
}
}
@@ -354,12 +354,13 @@ template <class C, class A1, class A2, class A3> class Cache3 {
class NonReversibleCache : public ModelCache {
public:
typedef Cache1<IntExpr, IntVar*> VarIntExprCache;
typedef Cache1<IntExpr, std::vector<IntVar*>*> VarArrayIntExprCache;
typedef Cache1<IntExpr, std::vector<IntVar*> > VarArrayIntExprCache;
typedef Cache2<Constraint, IntVar*, int64> VarConstantConstraintCache;
typedef Cache2<Constraint, IntVar*, IntVar*> VarVarConstraintCache;
typedef Cache2<IntExpr, IntVar*, int64> VarConstantIntExprCache;
typedef Cache2<IntExpr, IntVar*, IntVar*> VarVarIntExprCache;
typedef Cache2<IntExpr, IntExpr*, IntExpr*> ExprExprIntExprCache;
typedef Cache2<IntExpr, IntVar*, ConstIntArray*> VarConstantArrayIntExprCache;
typedef Cache3<IntExpr, IntVar*, int64, int64>
@@ -389,6 +390,9 @@ class NonReversibleCache : public ModelCache {
for (int i = 0; i < VAR_VAR_EXPRESSION_MAX; ++i) {
var_var_expressions_.push_back(new VarVarIntExprCache);
}
for (int i = 0; i < EXPR_EXPR_EXPRESSION_MAX; ++i) {
expr_expr_expressions_.push_back(new ExprExprIntExprCache);
}
for (int i = 0; i < VAR_CONSTANT_CONSTANT_EXPRESSION_MAX; ++i) {
var_constant_constant_expressions_.push_back(
new VarConstantConstantIntExprCache);
@@ -600,6 +604,35 @@ class NonReversibleCache : public ModelCache {
}
}
// Expr Expr Expression.
virtual IntExpr* FindExprExprExpression(
IntExpr* const var1,
IntExpr* const var2,
ExprExprExpressionType type) const {
DCHECK(var1 != NULL);
DCHECK(var2 != NULL);
DCHECK_GE(type, 0);
DCHECK_LT(type, VAR_VAR_EXPRESSION_MAX);
return expr_expr_expressions_[type]->Find(var1, var2);
}
virtual void InsertExprExprExpression(
IntExpr* const expression,
IntExpr* const var1,
IntExpr* const var2,
ExprExprExpressionType type) {
DCHECK(expression != NULL);
DCHECK(var1 != NULL);
DCHECK(var2 != NULL);
DCHECK_GE(type, 0);
DCHECK_LT(type, VAR_VAR_EXPRESSION_MAX);
if (solver()->state() != Solver::IN_SEARCH &&
expr_expr_expressions_[type]->Find(var1, var2) == NULL) {
expr_expr_expressions_[type]->UnsafeInsert(var1, var2, expression);
}
}
// Var Constant Constant Expression.
virtual IntExpr* FindVarConstantConstantExpression(
@@ -666,7 +699,7 @@ class NonReversibleCache : public ModelCache {
// Var Array Expression.
virtual IntExpr* FindVarArrayExpression(
std::vector<IntVar*>* const vars,
const std::vector<IntVar*>& vars,
VarArrayExpressionType type) const {
DCHECK_GE(type, 0);
DCHECK_LT(type, VAR_ARRAY_EXPRESSION_MAX);
@@ -675,7 +708,7 @@ class NonReversibleCache : public ModelCache {
virtual void InsertVarArrayExpression(
IntExpr* const expression,
std::vector<IntVar*>* const vars,
const std::vector<IntVar*>& vars,
VarArrayExpressionType type) {
DCHECK(expression != NULL);
DCHECK_GE(type, 0);
@@ -695,6 +728,7 @@ class NonReversibleCache : public ModelCache {
std::vector<VarIntExprCache*> var_expressions_;
std::vector<VarConstantIntExprCache*> var_constant_expressions_;
std::vector<VarVarIntExprCache*> var_var_expressions_;
std::vector<ExprExprIntExprCache*> expr_expr_expressions_;
std::vector<VarConstantConstantIntExprCache*> var_constant_constant_expressions_;
std::vector<VarConstantArrayIntExprCache*> var_constant_array_expressions_;
std::vector<VarArrayIntExprCache*> var_array_expressions_;