[CP-SAT] fix bug with reservoir expansion; improve hint processing; improve work sharing

This commit is contained in:
Laurent Perron
2024-11-25 14:49:56 +01:00
parent 7f55c23900
commit df2811878a
14 changed files with 312 additions and 56 deletions

View File

@@ -1720,7 +1720,7 @@ class ConstraintChecker {
current_level += delta.second;
if (current_level < min_level || current_level > max_level) {
VLOG(1) << "Reservoir level " << current_level
<< " is out of bounds at time" << delta.first;
<< " is out of bounds at time: " << delta.first;
return false;
}
}

View File

@@ -169,8 +169,8 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand,
context->UpdateRuleStats("reservoir: expanded using circuit.");
}
void ExpandReservoirUsingPrecedences(int64_t sum_of_positive_demand,
int64_t sum_of_negative_demand,
void ExpandReservoirUsingPrecedences(bool max_level_is_constraining,
bool min_level_is_constraining,
ConstraintProto* reservoir_ct,
PresolveContext* context) {
const ReservoirConstraintProto& reservoir = reservoir_ct->reservoir();
@@ -192,23 +192,21 @@ void ExpandReservoirUsingPrecedences(int64_t sum_of_positive_demand,
// No need for some constraints if the reservoir is just constrained in
// one direction.
if (demand_i > 0 && sum_of_positive_demand <= reservoir.max_level()) {
continue;
}
if (demand_i < 0 && sum_of_negative_demand >= reservoir.min_level()) {
continue;
}
if (demand_i > 0 && !max_level_is_constraining) continue;
if (demand_i < 0 && !min_level_is_constraining) continue;
ConstraintProto* new_ct = context->working_model->add_constraints();
LinearConstraintProto* new_linear = new_ct->mutable_linear();
// Add contributions from previous events.
ConstraintProto* new_cumul = context->working_model->add_constraints();
LinearConstraintProto* new_linear = new_cumul->mutable_linear();
int64_t offset = 0;
// Add contributions from events that happened at time_j <= time_i.
const LinearExpressionProto& time_i = reservoir.time_exprs(i);
for (int j = 0; j < num_events; ++j) {
if (i == j) continue;
const int active_j = is_active_literal(j);
if (context->LiteralIsFalse(active_j)) continue;
const int64_t demand_j = context->FixedValue(reservoir.level_changes(j));
if (demand_j == 0) continue;
// Get or create the literal equivalent to
// active_i && active_j && time[j] <= time[i].
@@ -218,18 +216,8 @@ void ExpandReservoirUsingPrecedences(int64_t sum_of_positive_demand,
const LinearExpressionProto& time_j = reservoir.time_exprs(j);
const int j_lesseq_i = context->GetOrCreateReifiedPrecedenceLiteral(
time_j, time_i, active_j, active_i);
context->working_model->mutable_variables(j_lesseq_i)
->set_name(absl::StrCat(j, " before ", i));
const int64_t demand = context->FixedValue(reservoir.level_changes(j));
if (RefIsPositive(j_lesseq_i)) {
new_linear->add_vars(j_lesseq_i);
new_linear->add_coeffs(demand);
} else {
new_linear->add_vars(NegatedRef(j_lesseq_i));
new_linear->add_coeffs(-demand);
offset -= demand;
}
AddWeightedLiteralToLinearConstraint(j_lesseq_i, demand_j, new_linear,
&offset);
}
// Add contribution from event i.
@@ -237,25 +225,21 @@ void ExpandReservoirUsingPrecedences(int64_t sum_of_positive_demand,
// TODO(user): Alternatively we can mark the whole constraint as enforced
// only if active_i is true. Experiments with both version, right now we
// miss enough benchmarks to conclude.
if (RefIsPositive(active_i)) {
new_linear->add_vars(active_i);
new_linear->add_coeffs(demand_i);
} else {
new_linear->add_vars(NegatedRef(active_i));
new_linear->add_coeffs(-demand_i);
offset -= demand_i;
}
AddWeightedLiteralToLinearConstraint(active_i, demand_i, new_linear,
&offset);
// Note that according to the sign of demand_i, we only need one side.
// We apply the offset here to make sure we use int64_t min and max.
if (demand_i > 0) {
new_linear->add_domain(std::numeric_limits<int64_t>::min());
new_linear->add_domain(reservoir.max_level());
new_linear->add_domain(reservoir.max_level() - offset);
} else {
new_linear->add_domain(reservoir.min_level());
new_linear->add_domain(reservoir.min_level() - offset);
new_linear->add_domain(std::numeric_limits<int64_t>::max());
}
context->CanonicalizeLinearConstraint(new_ct);
// Canonicalize the newly created constraint.
context->CanonicalizeLinearConstraint(new_cumul);
}
reservoir_ct->Clear();
@@ -367,9 +351,10 @@ void ExpandReservoir(ConstraintProto* reservoir_ct, PresolveContext* context) {
} else {
// This one is the faster option usually.
if (all_demands_are_fixed) {
ExpandReservoirUsingPrecedences(sum_of_positive_demand,
sum_of_negative_demand, reservoir_ct,
context);
ExpandReservoirUsingPrecedences(
sum_of_positive_demand > reservoir_ct->reservoir().max_level(),
sum_of_negative_demand < reservoir_ct->reservoir().min_level(),
reservoir_ct, context);
} else {
context->UpdateRuleStats(
"reservoir: skipped expansion due to variable demands");

View File

@@ -92,6 +92,29 @@ TEST(ReservoirExpandTest, NoOptionalAndInitiallyFeasible) {
EXPECT_EQ(27, solutions.size());
}
TEST(ReservoirExpandTest, SimpleSemaphore) {
const CpModelProto initial_model = ParseTestProto(R"pb(
variables { domain: 0 domain: 10 }
variables { domain: 0 domain: 10 }
variables { domain: 0 domain: 1 }
constraints {
reservoir {
max_level: 2
time_exprs { vars: 0 coeffs: 1 }
time_exprs { vars: 1 coeffs: 1 }
active_literals: [ 2, 2 ]
level_changes { offset: -1 }
level_changes { offset: 1 }
}
}
)pb");
absl::btree_set<std::vector<int>> solutions;
const CpSolverResponse response =
SolveAndCheck(initial_model, "", &solutions);
EXPECT_EQ(OPTIMAL, response.status());
EXPECT_EQ(187, solutions.size());
}
TEST(ReservoirExpandTest, GizaReport) {
const CpModelProto initial_model = ParseTestProto(R"pb(
variables { domain: 0 domain: 10 }

View File

@@ -444,7 +444,11 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) {
return MarkConstraintAsFalse(ct);
}
if (context_->VariableIsUniqueAndRemovable(literal)) {
// This is a "dual" reduction.
changed = true;
context_->UpdateRuleStats(
"bool_and: setting unused literal in rhs to true");
context_->UpdateLiteralSolutionHint(literal, true);
if (!context_->SetLiteralToTrue(literal)) return true;
continue;
}
@@ -6879,8 +6883,8 @@ bool CpModelPresolver::PresolveReservoir(ConstraintProto* ct) {
(num_positives == 0 || num_negatives == 0)) {
// If all level_changes have the same sign, and if the initial state is
// always feasible, we do not care about the order, just the sum.
auto* const sum =
context_->working_model->add_constraints()->mutable_linear();
auto* const sum_ct = context_->working_model->add_constraints();
auto* const sum = sum_ct->mutable_linear();
int64_t fixed_contrib = 0;
for (int i = 0; i < proto.level_changes_size(); ++i) {
const int64_t demand = context_->FixedValue(proto.level_changes(i));
@@ -6898,6 +6902,7 @@ bool CpModelPresolver::PresolveReservoir(ConstraintProto* ct) {
}
sum->add_domain(proto.min_level() - fixed_contrib);
sum->add_domain(proto.max_level() - fixed_contrib);
CanonicalizeLinear(sum_ct);
context_->UpdateRuleStats("reservoir: converted to linear");
return RemoveConstraint(ct);
}

View File

@@ -643,6 +643,20 @@ void AddLinearExpressionToLinearConstraint(const LinearExpressionProto& expr,
}
}
void AddWeightedLiteralToLinearConstraint(int lit, int64_t coeff,
LinearConstraintProto* linear,
int64_t* offset) {
if (coeff == 0) return;
if (RefIsPositive(lit)) {
linear->add_vars(lit);
linear->add_coeffs(coeff);
} else {
linear->add_vars(NegatedRef(lit));
linear->add_coeffs(-coeff);
*offset += coeff;
}
}
bool SafeAddLinearExpressionToLinearConstraint(
const LinearExpressionProto& expr, int64_t coefficient,
LinearConstraintProto* linear) {

View File

@@ -237,6 +237,13 @@ void AddLinearExpressionToLinearConstraint(const LinearExpressionProto& expr,
int64_t coefficient,
LinearConstraintProto* linear);
// Same as above, but with a single term (lit, coeff). Note that lit can be
// negative. The offset is relative to the linear expression (and should be
// negated when added to the rhs of the linear constraint proto).
void AddWeightedLiteralToLinearConstraint(int lit, int64_t coeff,
LinearConstraintProto* linear,
int64_t* offset);
// Same method, but returns if the addition was possible without overflowing.
bool SafeAddLinearExpressionToLinearConstraint(
const LinearExpressionProto& expr, int64_t coefficient,

View File

@@ -106,6 +106,8 @@ std::string ValidateParameters(const SatParameters& params) {
TEST_IN_RANGE(interleave_batch_size, 0, kMaxReasonableParallelism);
TEST_IN_RANGE(shared_tree_open_leaves_per_worker, 1,
kMaxReasonableParallelism);
TEST_IN_RANGE(shared_tree_balance_tolerance, 0,
log2(kMaxReasonableParallelism));
// TODO(user): Consider using annotations directly in the proto for these
// validation. It is however not open sourced.

View File

@@ -2319,7 +2319,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral(
if (!LiteralIsTrue(active_i)) {
AddImplication(result, active_i);
}
if (!LiteralIsTrue(active_j)) {
if (!LiteralIsTrue(active_j) && active_i != active_j) {
AddImplication(result, active_j);
}
@@ -2341,7 +2341,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral(
if (!LiteralIsTrue(active_i)) {
greater->add_enforcement_literal(active_i);
}
if (!LiteralIsTrue(active_j)) {
if (!LiteralIsTrue(active_j) && active_i != active_j) {
greater->add_enforcement_literal(active_j);
}
CanonicalizeLinearConstraint(greater);

View File

@@ -603,14 +603,18 @@ class PresolveContext {
hint_[var] == (RefIsPositive(lit) ? value : !value);
}
// If the given literal is already hinted, updates its hint.
// Otherwise do nothing.
void UpdateLiteralSolutionHint(int lit, bool value) {
UpdateSolutionHint(PositiveRef(lit), RefIsPositive(lit) ? value : !value);
UpdateSolutionHint(PositiveRef(lit), RefIsPositive(lit) == value ? 1 : 0);
}
// Updates the hint of an existing variable with an existing hint.
// If the given variable is already hinted, updates its hint value.
// Otherwise, do nothing.
void UpdateSolutionHint(int var, int64_t value) {
CHECK(hint_is_loaded_);
CHECK(hint_has_value_[var]);
DCHECK(RefIsPositive(var));
if (!hint_is_loaded_) return;
if (!hint_has_value_[var]) return;
hint_[var] = value;
}

View File

@@ -23,7 +23,7 @@ option java_multiple_files = true;
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
// NEXT TAG: 305
// NEXT TAG: 306
message SatParameters {
// In some context, like in a portfolio of search, it makes sense to name a
// given parameters set for logging purpose.
@@ -1141,6 +1141,15 @@ message SatParameters {
optional SharedTreeSplitStrategy shared_tree_split_strategy = 239
[default = SPLIT_STRATEGY_AUTO];
// How much deeper compared to the ideal max depth of the tree is considered
// "balanced" enough to still accept a split. Without such a tolerance,
// sometimes the tree can only be split by a single worker, and they may not
// generate a split for some time. In contrast, with a tolerance of 1, at
// least half of all workers should be able to split the tree as soon as a
// split becomes required. This only has an effect on
// SPLIT_STRATEGY_BALANCED_TREE and SPLIT_STRATEGY_DISCREPANCY.
optional int32 shared_tree_balance_tolerance = 305 [default = 1];
// Whether we enumerate all solutions of a problem without objective. Note
// that setting this to true automatically disable some presolve reduction
// that can remove feasible solution. That is it has the same effect as

View File

@@ -20,6 +20,7 @@
#include <cstdlib>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -1383,6 +1384,19 @@ void ScanModelForDualBoundStrengthening(
}
namespace {
std::optional<int64_t> GetRefSolutionHint(const PresolveContext& context,
int ref) {
const int var = PositiveRef(ref);
if (!context.VarHasSolutionHint(var)) return std::nullopt;
const int64_t var_hint = context.SolutionHint(var);
return RefIsPositive(ref) ? var_hint : -var_hint;
}
void SetRefSolutionHint(PresolveContext& context, int ref, int hint) {
context.UpdateSolutionHint(PositiveRef(ref),
RefIsPositive(ref) ? hint : -hint);
}
// Decrements the solution hint of `lit` and increments the solution hint of
// `dominating_lit` if both hint values are present and equal to 1 and 0,
// respectively.
@@ -1395,6 +1409,55 @@ void MaybeUpdateLiteralHintFromDominance(PresolveContext& context, int lit,
}
}
// Decrements the solution hint of `ref` by the minimum amount necessary to be
// in `domain`, and increments the solution hint of one or more
// `dominating_variables` by the same total amount. Does nothing if a hint is
// missing or if it is not possible to increment the hint of the dominating
// variables by the amount subtracted from the hint of the dominated variable.
//
// The lower bound of `domain` must be the lower bound of `ref`'s current domain
// in `context`.
void MaybeUpdateRefHintFromDominance(
PresolveContext& context, int ref, const Domain& domain,
const absl::Span<const IntegerVariable> dominating_variables) {
const std::optional<int64_t> ref_hint = GetRefSolutionHint(context, ref);
if (!ref_hint.has_value()) return;
// The quantity to subtract from the solution hint of `ref`.
const int64_t ref_hint_delta = *ref_hint - domain.ClosestValue(*ref_hint);
// If it is 0 there is nothing to do. It might be negative if the solution
// hint is not initially feasible (in which case we can't fix it).
if (ref_hint_delta <= 0) return;
// First step: check that the hint of the dominating variable(s) can be
// incremented by ref_hint_delta (possibly spread over multiple variables),
// and store the new hint values in `new_ref_hint_value_pairs`.
std::vector<std::pair<int, int64_t>> new_ref_hint_value_pairs;
new_ref_hint_value_pairs.push_back({ref, *ref_hint - ref_hint_delta});
int64_t remaining_delta = ref_hint_delta;
for (const IntegerVariable ivar : dominating_variables) {
const int dominating_ref = VarDomination::IntegerVariableToRef(ivar);
const std::optional<int64_t> dominating_ref_hint =
GetRefSolutionHint(context, dominating_ref);
if (!dominating_ref_hint.has_value()) continue;
const int64_t delta =
context.DomainOf(dominating_ref)
.ClosestValue(*dominating_ref_hint + remaining_delta) -
*dominating_ref_hint;
// This might happen if the solution hint is not initially feasible.
if (delta < 0) continue;
new_ref_hint_value_pairs.push_back(
{dominating_ref, *dominating_ref_hint + delta});
remaining_delta -= delta;
if (remaining_delta == 0) break;
}
if (remaining_delta != 0) return;
// Second step: actually update the hints.
for (const auto& [ref, hint] : new_ref_hint_value_pairs) {
SetRefSolutionHint(context, ref, hint);
}
}
bool ProcessAtMostOne(
absl::Span<const int> literals, const std::string& message,
const VarDomination& var_domination,
@@ -1668,7 +1731,10 @@ bool ExploitDominanceRelations(const VarDomination& var_domination,
const int64_t lb = context->MinOf(current_ref);
if (delta + coeff_magnitude > slack) {
context->UpdateRuleStats("domination: fixed to lb.");
if (!context->IntersectDomainWith(current_ref, Domain(lb))) {
const Domain reduced_domain = Domain(lb);
MaybeUpdateRefHintFromDominance(*context, current_ref, reduced_domain,
dominated_by);
if (!context->IntersectDomainWith(current_ref, reduced_domain)) {
return false;
}
@@ -1699,7 +1765,10 @@ bool ExploitDominanceRelations(const VarDomination& var_domination,
}
if (new_ub < context->MaxOf(current_ref)) {
context->UpdateRuleStats("domination: reduced ub.");
if (!context->IntersectDomainWith(current_ref, Domain(lb, new_ub))) {
const Domain reduced_domain = Domain(lb, new_ub);
MaybeUpdateRefHintFromDominance(*context, current_ref, reduced_domain,
dominated_by);
if (!context->IntersectDomainWith(current_ref, reduced_domain)) {
return false;
}
@@ -1807,8 +1876,14 @@ bool ExploitDominanceRelations(const VarDomination& var_domination,
increase_is_forbidden[var] = true;
context->UpdateRuleStats(
"domination: dual strenghtening using dominance");
if (!context->IntersectDomainWith(
ref, Domain(context->MinOf(ref), lb))) {
const Domain reduced_domain = Domain(context->MinOf(ref), lb);
const std::optional<int64_t> ref_hint =
GetRefSolutionHint(*context, ref);
if (ref_hint.has_value()) {
SetRefSolutionHint(*context, ref,
reduced_domain.ClosestValue(*ref_hint));
}
if (!context->IntersectDomainWith(ref, reduced_domain)) {
return false;
}

View File

@@ -13,8 +13,6 @@
#include "ortools/sat/var_domination.h"
#include <string>
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/base/parse_test_proto.h"
@@ -226,6 +224,9 @@ TEST(VarDominationTest, ExploitDominanceOfImplicant) {
ScanModelForDominanceDetection(context, &var_dom);
EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context));
const IntegerVariable X = VarDomination::RefToIntegerVariable(0);
const IntegerVariable Y = VarDomination::RefToIntegerVariable(1);
EXPECT_THAT(var_dom.DominatingVariables(X), ElementsAre(NegationOf(Y)));
EXPECT_EQ(context.DomainOf(0).ToString(), "[0]");
EXPECT_EQ(context.DomainOf(1).ToString(), "[0]");
EXPECT_EQ(context.SolutionHint(0), 0);
@@ -272,6 +273,9 @@ TEST(VarDominationTest, ExploitDominanceOfNegatedImplicand) {
ScanModelForDominanceDetection(context, &var_dom);
EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context));
const IntegerVariable X = VarDomination::RefToIntegerVariable(0);
const IntegerVariable Y = VarDomination::RefToIntegerVariable(1);
EXPECT_THAT(var_dom.DominatingVariables(NegationOf(X)), ElementsAre(Y));
EXPECT_EQ(context.DomainOf(0).ToString(), "[1]");
EXPECT_EQ(context.DomainOf(1).ToString(), "[1]");
EXPECT_EQ(context.SolutionHint(0), 1);
@@ -315,12 +319,72 @@ TEST(VarDominationTest, ExploitDominanceInExactlyOne) {
ScanModelForDominanceDetection(context, &var_dom);
EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context));
const IntegerVariable X = VarDomination::RefToIntegerVariable(0);
const IntegerVariable Y = VarDomination::RefToIntegerVariable(1);
EXPECT_THAT(var_dom.DominatingVariables(X), ElementsAre(Y));
EXPECT_EQ(context.DomainOf(0).ToString(), "[0]");
EXPECT_EQ(context.DomainOf(1).ToString(), "[0,1]");
EXPECT_EQ(context.SolutionHint(0), 0);
EXPECT_EQ(context.SolutionHint(1), 1);
}
// Objective: min(X + Y + 2Z)
// Constraint: X + Y + Z <= 10
// X, Y in [-10, 10], Z in [5, 10]
//
// Doing (X++, Z--) or (Y++, Z--) is always beneficial if possible.
TEST(VarDominationTest, ExploitDominanceWithIntegerVariables) {
CpModelProto model_proto = ParseTestProto(R"pb(
variables {
name: "X"
domain: [ -10, 10 ]
}
variables {
name: "Y"
domain: [ -10, 10 ]
}
variables {
name: "Z"
domain: [ 5, 10 ]
}
constraints {
linear {
vars: [ 0, 1, 2 ]
coeffs: [ 1, 1, 1 ]
domain: [ 0, 10 ]
}
}
objective {
vars: [ 0, 1, 2 ]
coeffs: [ 1, 1, 2 ]
}
solution_hint {
vars: [ 0, 1, 2 ]
values: [ 1, 1, 8 ]
}
)pb");
VarDomination var_dom;
Model model;
PresolveContext context(&model, &model_proto, nullptr);
context.InitializeNewDomains();
context.ReadObjectiveFromProto();
context.UpdateNewConstraintsVariableUsage();
context.LoadSolutionHint();
ScanModelForDominanceDetection(context, &var_dom);
EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context));
const IntegerVariable X = VarDomination::RefToIntegerVariable(0);
const IntegerVariable Y = VarDomination::RefToIntegerVariable(1);
const IntegerVariable Z = VarDomination::RefToIntegerVariable(2);
EXPECT_THAT(var_dom.DominatingVariables(Z), ElementsAre(X, Y));
EXPECT_EQ(context.DomainOf(0).ToString(), "[-5]");
EXPECT_EQ(context.DomainOf(1).ToString(), "[0,10]");
EXPECT_EQ(context.DomainOf(2).ToString(), "[5]");
EXPECT_EQ(context.SolutionHint(0), -5);
EXPECT_EQ(context.SolutionHint(1), 10);
EXPECT_EQ(context.SolutionHint(2), 5);
}
// Objective: min(X + 2Y)
// Constraint: BoolOr(X, Y)
//
@@ -369,6 +433,70 @@ TEST(VarDominationTest, ExploitRemainingDominance) {
EXPECT_EQ(context.SolutionHint(1), 0);
}
// Objective: min(X)
// Constraint: -5 <= X + Y <= 5
// Constraint: -15 <= Y + Z <= 15
// X,Y in [-10, 10], Z in [-5, 5]
//
// Doing (X--, Y++) is always beneficial if possible.
TEST(VarDominationTest, ExploitRemainingDominanceWithIntegerVariables) {
CpModelProto model_proto = ParseTestProto(R"pb(
variables {
name: "X"
domain: [ -10, 10 ]
}
variables {
name: "Y"
domain: [ -10, 10 ]
}
variables {
name: "Z"
domain: [ -5, 5 ]
}
objective {
vars: [ 0 ]
coeffs: [ 1 ]
}
constraints {
linear {
vars: [ 0, 1 ]
coeffs: [ 1, 1 ]
domain: [ -5, 5 ]
}
}
constraints {
linear {
vars: [ 1, 2 ]
coeffs: [ 1, 1 ]
domain: [ -15, 15 ]
}
}
solution_hint {
vars: [ 0, 1, 2 ]
values: [ 0, 1, 2 ]
}
)pb");
VarDomination var_dom;
Model model;
PresolveContext context(&model, &model_proto, nullptr);
context.InitializeNewDomains();
context.ReadObjectiveFromProto();
context.UpdateNewConstraintsVariableUsage();
context.LoadSolutionHint();
ScanModelForDominanceDetection(context, &var_dom);
EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context));
const IntegerVariable X = VarDomination::RefToIntegerVariable(0);
const IntegerVariable Y = VarDomination::RefToIntegerVariable(1);
EXPECT_THAT(var_dom.DominatingVariables(X), ElementsAre(Y));
EXPECT_EQ(context.DomainOf(0).ToString(), "[-10,-5]");
EXPECT_EQ(context.DomainOf(1).ToString(), "[5,10]");
EXPECT_EQ(context.DomainOf(2).ToString(), "[5]");
EXPECT_EQ(context.SolutionHint(0), -5);
EXPECT_EQ(context.SolutionHint(1), 6);
EXPECT_EQ(context.SolutionHint(2), 5);
}
// X + Y + Z = 0
// X + 2 Z >= 2
TEST(VarDominationTest, BasicExample1Variation) {

View File

@@ -342,7 +342,8 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) {
// TODO(user): Need to write up the shape this creates.
// This rule will allow twice as many leaves in the preferred subtree.
if (discrepancy + path.MaxLevel() >
MaxAllowedDiscrepancyPlusDepth(num_desired_leaves)) {
MaxAllowedDiscrepancyPlusDepth(num_desired_leaves) +
params_.shared_tree_balance_tolerance()) {
VLOG(2) << "Too high discrepancy to accept split";
return;
}
@@ -356,7 +357,8 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) {
}
} else if (params_.shared_tree_split_strategy() ==
SatParameters::SPLIT_STRATEGY_BALANCED_TREE) {
if (path.MaxLevel() + 1 > log2(num_desired_leaves)) {
if (path.MaxLevel() + 1 >
log2(num_desired_leaves) + params_.shared_tree_balance_tolerance()) {
VLOG(2) << "Tree too unbalanced to accept split";
return;
}

View File

@@ -458,6 +458,7 @@ TEST(SharedTreeManagerTest, BalancedSplitTestOneLeafPerWorker) {
params.set_cp_model_presolve(false);
params.set_shared_tree_split_strategy(
SatParameters::SPLIT_STRATEGY_BALANCED_TREE);
params.set_shared_tree_balance_tolerance(0);
model.Add(NewSatParameters(params));
LoadVariables(model_builder.Build(), false, &model);
auto* response_manager = model.GetOrCreate<SharedResponseManager>();
@@ -493,6 +494,7 @@ TEST(SharedTreeManagerTest, BalancedSplitTest) {
params.set_cp_model_presolve(false);
params.set_shared_tree_split_strategy(
SatParameters::SPLIT_STRATEGY_BALANCED_TREE);
params.set_shared_tree_balance_tolerance(0);
model.Add(NewSatParameters(params));
LoadVariables(model_builder.Build(), false, &model);
auto* response_manager = model.GetOrCreate<SharedResponseManager>();