[CP-SAT] speed up no_overlap violation code; improve glue-clause sharing

This commit is contained in:
Laurent Perron
2024-06-05 09:10:13 +02:00
parent 5b87d86172
commit 529578ef0f
12 changed files with 207 additions and 81 deletions

View File

@@ -368,6 +368,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf",
],
@@ -2000,6 +2001,7 @@ cc_library(
"//ortools/util:time_limit",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",

View File

@@ -46,6 +46,23 @@ int64_t ExprValue(const LinearExpressionProto& expr,
return result;
}
LinearExpressionProto ExprDiff(const LinearExpressionProto& a,
const LinearExpressionProto& b) {
LinearExpressionProto result;
result.set_offset(a.offset() - b.offset());
result.mutable_vars()->Reserve(a.vars().size() + b.vars().size());
result.mutable_coeffs()->Reserve(a.vars().size() + b.vars().size());
for (int i = 0; i < a.vars().size(); ++i) {
result.add_vars(a.vars(i));
result.add_coeffs(a.coeffs(i));
}
for (int i = 0; i < b.vars().size(); ++i) {
result.add_vars(b.vars(i));
result.add_coeffs(-b.coeffs(i));
}
return result;
}
int64_t ExprMin(const LinearExpressionProto& expr, const CpModelProto& model) {
int64_t result = expr.offset();
for (int i = 0; i < expr.vars_size(); ++i) {
@@ -1174,15 +1191,44 @@ CompiledNoOverlapConstraint::CompiledNoOverlapConstraint(
int64_t CompiledNoOverlapConstraint::ComputeViolation(
absl::Span<const int64_t> solution) {
DCHECK_GE(ct_proto().no_overlap().intervals_size(), 2);
if (ct_proto().no_overlap().intervals_size() == 2) {
return NoOverlapMinRepairDistance(
cp_model_.constraints(ct_proto().no_overlap().intervals(0)),
cp_model_.constraints(ct_proto().no_overlap().intervals(1)), solution);
}
return ComputeOverloadArea(ct_proto().no_overlap().intervals(), {}, cp_model_,
solution, 1, events_);
}
NoOverlapBetweenTwoIntervals::NoOverlapBetweenTwoIntervals(
const ConstraintProto& ct_proto, const CpModelProto& cp_model)
: CompiledConstraint(ct_proto) {
CHECK_EQ(ct_proto.no_overlap().intervals().size(), 2);
const ConstraintProto& ct0 =
cp_model.constraints(ct_proto.no_overlap().intervals(0));
const ConstraintProto& ct1 =
cp_model.constraints(ct_proto.no_overlap().intervals(1));
// The more compact the better, hence the size + int[].
num_enforcements_ =
ct0.enforcement_literal().size() + ct1.enforcement_literal().size();
if (num_enforcements_ > 0) {
enforcements_.reset(new int[num_enforcements_]);
int i = 0;
for (const int lit : ct0.enforcement_literal()) enforcements_[i++] = lit;
for (const int lit : ct1.enforcement_literal()) enforcements_[i++] = lit;
}
end_minus_start_1_ = ExprDiff(ct0.interval().end(), ct1.interval().start());
end_minus_start_2_ = ExprDiff(ct1.interval().end(), ct0.interval().start());
}
// Same as NoOverlapMinRepairDistance().
int64_t NoOverlapBetweenTwoIntervals::ComputeViolationInternal(
absl::Span<const int64_t> solution) {
for (int i = 0; i < num_enforcements_; ++i) {
if (!LiteralValue(enforcements_[i], solution)) return 0;
}
const int64_t diff1 = ExprValue(end_minus_start_1_, solution);
const int64_t diff2 = ExprValue(end_minus_start_2_, solution);
return std::max(std::min(diff1, diff2), int64_t{0});
}
// ----- CompiledCumulativeConstraint -----
CompiledCumulativeConstraint::CompiledCumulativeConstraint(
@@ -1571,11 +1617,9 @@ void LsEvaluator::CompileOneConstraint(const ConstraintProto& ct) {
case ConstraintProto::ConstraintCase::kNoOverlap: {
const int size = ct.no_overlap().intervals_size();
if (size <= 1) break;
if (size == 2 ||
size > params_.feasibility_jump_max_expanded_constraint_size()) {
CompiledNoOverlapConstraint* no_overlap =
new CompiledNoOverlapConstraint(ct, cp_model_);
constraints_.emplace_back(no_overlap);
if (size > params_.feasibility_jump_max_expanded_constraint_size()) {
constraints_.emplace_back(
new CompiledNoOverlapConstraint(ct, cp_model_));
} else {
// We expand the no_overlap constraints into a quadratic number of
// disjunctions.
@@ -1595,8 +1639,8 @@ void LsEvaluator::CompileOneConstraint(const ConstraintProto& ct) {
ct.no_overlap().intervals(i));
disj->mutable_no_overlap()->add_intervals(
ct.no_overlap().intervals(j));
CompiledNoOverlapConstraint* disjunction =
new CompiledNoOverlapConstraint(*disj, cp_model_);
NoOverlapBetweenTwoIntervals* disjunction =
new NoOverlapBetweenTwoIntervals(*disj, cp_model_);
constraints_.emplace_back(disjunction);
}
}

View File

@@ -243,8 +243,9 @@ class LinearIncrementalEvaluator {
// View of a generic (non linear) constraint for the LsEvaluator.
//
// TODO(user): Do we add a Update(solution, var, new_value) method ?
// TODO(user): Do we want to support Update(solutions, vars, new_values) ?
// TODO(user): Remove the ct_proto() from here and instead expose a
// UsedVariables(). It is inefficient to use a proto for compiled constraint not
// based on one.
class CompiledConstraint {
public:
explicit CompiledConstraint(const ConstraintProto& ct_proto);
@@ -533,6 +534,36 @@ class CompiledNoOverlapConstraint : public CompiledConstraint {
std::vector<std::pair<int64_t, int64_t>> events_;
};
// Special constraint for no overlap between two intervals.
// We usually expand small no-overlaps into n^2 such constraints, so we want to
// be compact and efficient here.
class NoOverlapBetweenTwoIntervals : public CompiledConstraint {
public:
NoOverlapBetweenTwoIntervals(const ConstraintProto& ct_proto,
const CpModelProto& cp_model);
~NoOverlapBetweenTwoIntervals() override = default;
int64_t ComputeViolation(absl::Span<const int64_t> solution) final {
return ComputeViolationInternal(solution);
}
// Note(user): this is the same implementation as the base one, but it
// avoids one virtual call!
int64_t ViolationDelta(
int /*var*/, int64_t /*old_value*/,
absl::Span<const int64_t> solution_with_new_value) final {
return ComputeViolationInternal(solution_with_new_value) - violation();
}
private:
int64_t ComputeViolationInternal(absl::Span<const int64_t> solution);
int num_enforcements_;
std::unique_ptr<int[]> enforcements_;
LinearExpressionProto end_minus_start_1_;
LinearExpressionProto end_minus_start_2_;
};
// The violation of a cumulative is the sum of overloads over time.
class CompiledCumulativeConstraint : public CompiledConstraint {
public:

View File

@@ -8240,26 +8240,30 @@ bool CpModelPresolver::ProcessSetPPCSubset(int subset_c, int superset_c,
}
}
if (best != 0) {
LinearConstraintProto new_ct = superset_ct->linear();
int new_size = 0;
for (int i = 0; i < superset_ct->linear().vars().size(); ++i) {
const int var = superset_ct->linear().vars(i);
int64_t coeff = superset_ct->linear().coeffs(i);
for (int i = 0; i < new_ct.vars().size(); ++i) {
const int var = new_ct.vars(i);
int64_t coeff = new_ct.coeffs(i);
if (tmp_set->contains(var)) {
if (coeff == best) continue; // delete term.
coeff -= best;
}
superset_ct->mutable_linear()->set_vars(new_size, var);
superset_ct->mutable_linear()->set_coeffs(new_size, coeff);
new_ct.set_vars(new_size, var);
new_ct.set_coeffs(new_size, coeff);
++new_size;
}
superset_ct->mutable_linear()->mutable_vars()->Truncate(new_size);
superset_ct->mutable_linear()->mutable_coeffs()->Truncate(new_size);
FillDomainInProto(ReadDomainFromProto(superset_ct->linear())
.AdditionWith(Domain(-best)),
superset_ct->mutable_linear());
context_->UpdateConstraintVariableUsage(superset_c);
context_->UpdateRuleStats("setppc: reduced linear coefficients");
new_ct.mutable_vars()->Truncate(new_size);
new_ct.mutable_coeffs()->Truncate(new_size);
FillDomainInProto(ReadDomainFromProto(new_ct).AdditionWith(Domain(-best)),
&new_ct);
if (!PossibleIntegerOverflow(*context_->working_model, new_ct.vars(),
new_ct.coeffs())) {
*superset_ct->mutable_linear() = std::move(new_ct);
context_->UpdateConstraintVariableUsage(superset_c);
context_->UpdateRuleStats("setppc: reduced linear coefficients");
}
}
return true;

View File

@@ -16,6 +16,7 @@
#include <algorithm>
#include <array>
#include <atomic>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <deque>
@@ -50,6 +51,7 @@
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/text_format.h"
@@ -1427,12 +1429,16 @@ void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
return;
}
auto* clause_stream = shared_clauses_manager->GetClauseStream(id);
const int max_lbd =
model->GetOrCreate<SatParameters>()->clause_cleanup_lbd_bound();
// Note that this callback takes no global locks, everything operates on this
// worker's own clause stream, whose lock is only used by this worker, and
// briefly when generating a batch in SharedClausesManager::Synchronize().
auto share_clause = [mapping, clause_stream, clause = std::vector<int>()](
auto share_clause = [mapping, clause_stream, max_lbd,
clause = std::vector<int>()](
int lbd, absl::Span<const Literal> literals) mutable {
if (lbd <= 0 || lbd > 2 || !clause_stream->CanAccept(literals.size())) {
if (lbd <= 0 || lbd > max_lbd ||
!clause_stream->CanAccept(literals.size(), lbd)) {
return;
}
clause.clear();
@@ -2409,7 +2415,8 @@ struct SharedClasses {
!params.interleave_search() || params.num_workers() <= 1;
response->SetSynchronizationMode(always_synchronize);
if (params.share_binary_clauses() && params.num_workers() > 1) {
clauses = std::make_unique<SharedClausesManager>(always_synchronize);
clauses = std::make_unique<SharedClausesManager>(always_synchronize,
absl::Seconds(1));
}
}
@@ -3103,9 +3110,9 @@ class LnsSolver : public SubSolver {
return [task_id, this]() {
if (shared_->SearchIsDone()) return;
// Create a random number generator whose seed depends both on the task_id
// and on the parameters_.random_seed() so that changing the latter will
// change the LNS behavior.
// Create a random number generator whose seed depends both on the
// task_id and on the parameters_.random_seed() so that changing the
// latter will change the LNS behavior.
const int32_t low = static_cast<int32_t>(task_id);
const int32_t high = static_cast<int32_t>(task_id >> 32);
std::seed_seq seed{low, high, lns_parameters_.random_seed()};
@@ -3234,9 +3241,9 @@ class LnsSolver : public SubSolver {
}
if (neighborhood.is_simple &&
neighborhood.num_relaxed_variables_in_objective == 0) {
// If we didn't relax the objective, there can be no improving solution.
// However, we might have some diversity if there are multiple feasible
// solutions.
// If we didn't relax the objective, there can be no improving
// solution. However, we might have some diversity if there are
// multiple feasible solutions.
//
// TODO(user): How can we tweak the search to favor diversity.
if (generator_->num_consecutive_non_improving_calls() > 10) {
@@ -3353,8 +3360,8 @@ class LnsSolver : public SubSolver {
// Special case if we solved a part of the full problem!
//
// TODO(user): This does not work if there are symmetries loaded into SAT.
// For now we just disable this if there is any symmetry. See for
// TODO(user): This does not work if there are symmetries loaded into
// SAT. For now we just disable this if there is any symmetry. See for
// instance spot5_1401.fzn. Be smarter about that.
//
// The issue is that as we fix level zero variables from a partial
@@ -3389,8 +3396,8 @@ class LnsSolver : public SubSolver {
shared_->model_proto.objective(), solution_values));
}
// Report any feasible solution we have. Optimization: We don't do that
// if we just recovered the base solution.
// Report any feasible solution we have. Optimization: We don't do
// that if we just recovered the base solution.
if (data.status == CpSolverStatus::OPTIMAL ||
data.status == CpSolverStatus::FEASIBLE) {
const std::vector<int64_t> base_solution(
@@ -3574,8 +3581,8 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
// Adds first solution subsolvers.
//
// The logic is the following. Before the first solution is found, we have (in
// order):
// The logic is the following. Before the first solution is found, we have
// (in order):
// - num_full_problem_solvers full problem solvers
// - num_workers - num_full_problem_solvers -
// num_dedicated_incomplete_solvers first solution solvers.

View File

@@ -16,9 +16,7 @@
#include <functional>
#include <string>
#include <vector>
#include "ortools/base/types.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_parameters.pb.h"

View File

@@ -414,21 +414,25 @@ std::unique_ptr<Graph> GenerateGraphForSymmetryDetection(
// Note(user): This requires that intervals appear before they are used.
// We currently enforce this at validation, otherwise we need two passes
// here and in a bunch of other places.
//
// TODO(user): With this graph encoding, we lose the symmetry that the
// dimension x can be swapped with the dimension y. I think it is
// possible to encode this by creating two extra nodes X and
// Y, each connected to all the x and all the y, but I have to think
// more about it.
CHECK_EQ(constraint_node, new_node(color));
std::vector<int64_t> local_color = color;
local_color.push_back(0);
const int size = constraint.no_overlap_2d().x_intervals().size();
const int node_x = new_node(local_color);
const int node_y = new_node(local_color);
local_color.pop_back();
graph->AddArc(constraint_node, node_x);
graph->AddArc(constraint_node, node_y);
local_color.push_back(1);
for (int i = 0; i < size; ++i) {
const int box_node = new_node(local_color);
graph->AddArc(box_node, constraint_node);
const int x = constraint.no_overlap_2d().x_intervals(i);
const int y = constraint.no_overlap_2d().y_intervals(i);
graph->AddArc(interval_constraint_index_to_node.at(x),
constraint_node);
graph->AddArc(interval_constraint_index_to_node.at(x),
interval_constraint_index_to_node.at(y));
graph->AddArc(interval_constraint_index_to_node.at(x), node_x);
graph->AddArc(interval_constraint_index_to_node.at(x), box_node);
graph->AddArc(interval_constraint_index_to_node.at(y), node_y);
graph->AddArc(interval_constraint_index_to_node.at(y), box_node);
}
break;
}

View File

@@ -1313,8 +1313,8 @@ different, but we do not care when they are inactive (represented by being
assigned a zero value).
To implement this constraint, we will collect all values in the initial domain
of all variables and attach Boolean variables for each of them.
This requires reading back the values from the model.
of all variables and attach Boolean variables for each of them. This requires
reading back the values from the model.
### Python code

View File

@@ -301,29 +301,25 @@ void FeasibilityJumpSolver::PerturbateCurrentSolution() {
std::string FeasibilityJumpSolver::OneLineStats() const {
// Restarts, perturbations, and solutions imported.
std::string restart_str;
if (num_restarts_ > 1) {
if (type() == SubSolver::INCOMPLETE) {
absl::StrAppend(&restart_str, " rst{imports:", num_solutions_imported_);
absl::StrAppend(&restart_str, " perturbs:", num_perturbations_, "}");
} else {
absl::StrAppend(&restart_str, " #restarts:", num_restarts_ - 1);
}
if (num_solutions_imported_ > 0) {
absl::StrAppend(&restart_str,
" #solutions_imported:", num_solutions_imported_);
}
if (num_perturbations_ > 0) {
absl::StrAppend(&restart_str, " #perturbations:", num_perturbations_);
}
// Moves and evaluations in the general iterations.
const std::string general_str =
num_general_evals_ == 0 && num_general_moves_ == 0
? ""
: absl::StrCat(" #gen_moves:", FormatCounter(num_general_moves_),
" #gen_evals:", FormatCounter(num_general_evals_));
: absl::StrCat(" gen{mvs:", FormatCounter(num_general_moves_),
" evals:", FormatCounter(num_general_evals_), "}");
const std::string compound_str =
num_compound_moves_ == 0 && move_->NumBacktracks() == 0
? ""
: absl::StrCat(
" #comp_moves:", FormatCounter(num_compound_moves_),
" #backtracks:", FormatCounter(move_->NumBacktracks()));
: absl::StrCat(" comp{mvs:", FormatCounter(num_compound_moves_),
" btracks:", FormatCounter(move_->NumBacktracks()),
"}");
// Improving jumps and infeasible constraints.
const int num_infeasible_cts = evaluator_->NumInfeasibleConstraints();
@@ -335,10 +331,10 @@ std::string FeasibilityJumpSolver::OneLineStats() const {
FormatCounter(evaluator_->NumInfeasibleConstraints()));
return absl::StrCat("batch:", num_batches_, restart_str,
" #lin_moves:", FormatCounter(num_linear_moves_),
" #lin_evals:", FormatCounter(num_linear_evals_),
" lin{mvs:", FormatCounter(num_linear_moves_),
" evals:", FormatCounter(num_linear_evals_), "}",
general_str, compound_str, non_solution_str,
" #weight_updates:", FormatCounter(num_weight_updates_));
" #w_updates:", FormatCounter(num_weight_updates_));
}
std::function<void()> FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) {

View File

@@ -30,6 +30,7 @@
#include <vector>
#include "absl/hash/hash.h"
#include "absl/time/time.h"
#include "ortools/base/logging.h"
#include "ortools/base/timer.h"
#if !defined(__PORTABLE_PLATFORM__)
@@ -1121,13 +1122,18 @@ int UniqueClauseStream::NumBufferedLiteralsUpToSize(int max_size) const {
return result;
}
bool UniqueClauseStream::CanAccept(int size) const {
bool UniqueClauseStream::CanAccept(int size, int lbd) const {
absl::MutexLock mutex_lock(&mutex_);
return size > 2 && size <= kMaxClauseSize &&
return size > 2 && size <= kMaxClauseSize && lbd <= lbd_threshold_ &&
clauses_by_size_[size - 3].size() + size <=
kMaxBufferedLiteralsPerSize;
}
void UniqueClauseStream::set_lbd_threshold(int lbd) {
absl::MutexLock mutex_lock(&mutex_);
lbd_threshold_ = lbd;
}
size_t UniqueClauseStream::HashClause(absl::Span<const int> clause,
size_t hash_seed) {
size_t hash = absl::HashOf(hash_seed, clause.size());
@@ -1153,8 +1159,10 @@ int UniqueClauseStream::NumClauses(int size) const {
return clauses_by_size_[size - 3].size() / size;
};
SharedClausesManager::SharedClausesManager(bool always_synchronize)
: always_synchronize_(always_synchronize) {}
SharedClausesManager::SharedClausesManager(bool always_synchronize,
absl::Duration share_frequency)
: always_synchronize_(always_synchronize),
share_frequency_(share_frequency) {}
int SharedClausesManager::RegisterNewId() {
absl::MutexLock mutex_lock(&mutex_);
@@ -1242,6 +1250,9 @@ void SharedClausesManager::Synchronize() {
last_visible_binary_clause_ = added_binary_clauses_.size();
const int num_workers = id_to_clause_stream_.size();
if (num_workers <= 1) return;
if (!share_timer_.IsRunning()) share_timer_.Start();
if (share_timer_.GetDuration() < share_frequency_) return;
share_timer_.Restart();
std::vector<int> ids(num_workers);
for (int size = 3; size < UniqueClauseStream::kMaxClauseSize; ++size) {
ids.clear();
@@ -1276,8 +1287,27 @@ void SharedClausesManager::Synchronize() {
}
}
}
if (all_clauses_.NumBufferedLiterals() >
UniqueClauseStream::kMaxLiteralsPerBatch / 2) {
// Tune LBD threshold for individual workers based on how full the batch and
// worker's buffer is.
const bool underfull = all_clauses_.NumBufferedLiterals() <=
UniqueClauseStream::kMaxLiteralsPerBatch -
UniqueClauseStream::kMaxClauseSize;
for (int id = 0; id < num_workers; ++id) {
UniqueClauseStream& stream = id_to_clause_stream_[id];
const int lbd_threshold = stream.lbd_threshold();
// Half a batch left after sharing! Focus on lower LBD clauses.
const bool overfull = stream.NumBufferedLiterals() >
UniqueClauseStream::kMaxLiteralsPerBatch / 2;
const int new_lbd = std::clamp(lbd_threshold + (overfull ? -1 : underfull),
2, UniqueClauseStream::kMaxClauseSize);
if (new_lbd != lbd_threshold) {
VLOG(2) << id_to_worker_name_[id]
<< " sharing clauses with lbd <= " << new_lbd;
stream.set_lbd_threshold(new_lbd);
}
}
if (all_clauses_.NumBufferedLiterals() > 0) {
batches_.push_back(all_clauses_.NextBatch());
VLOG(2) << "Batch #" << batches_.size() << " w/ " << batches_.back().size()
<< " clauses max size = "

View File

@@ -16,7 +16,6 @@
#include <array>
#include <atomic>
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <deque>
@@ -34,6 +33,7 @@
#include "absl/random/random.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "ortools/base/logging.h"
#include "ortools/base/stl_util.h"
@@ -635,7 +635,13 @@ class UniqueClauseStream {
// Returns true if the stream can accept a clause of the specified size
// without dropping it.
bool CanAccept(int size) const;
bool CanAccept(int size, int lbd) const;
int lbd_threshold() const ABSL_LOCKS_EXCLUDED(mutex_) {
absl::MutexLock lock(&mutex_);
return lbd_threshold_;
}
void set_lbd_threshold(int lbd) ABSL_LOCKS_EXCLUDED(mutex_);
// Computes a hash that is independent of the order of literals in the clause.
static size_t HashClause(absl::Span<const int> clause, size_t hash_seed = 0);
@@ -653,6 +659,7 @@ class UniqueClauseStream {
int NumClauses(int size) const ABSL_SHARED_LOCKS_REQUIRED(mutex_);
mutable absl::Mutex mutex_;
int lbd_threshold_ ABSL_GUARDED_BY(mutex_) = 2;
absl::flat_hash_set<size_t> fingerprints_ ABSL_GUARDED_BY(mutex_);
std::array<std::vector<int>, kMaxClauseSize - 2> clauses_by_size_
ABSL_GUARDED_BY(mutex_);
@@ -667,7 +674,8 @@ class UniqueClauseStream {
// literals can be negative numbers.
class SharedClausesManager {
public:
explicit SharedClausesManager(bool always_synchronize);
explicit SharedClausesManager(bool always_synchronize,
absl::Duration share_frequency);
void AddBinaryClause(int id, int lit1, int lit2);
// Returns new glue clauses.
@@ -721,8 +729,10 @@ class SharedClausesManager {
std::vector<int> id_to_last_finished_batch_ ABSL_GUARDED_BY(mutex_);
std::deque<CompactVectorVector<int>> batches_ ABSL_GUARDED_BY(mutex_);
std::deque<UniqueClauseStream> id_to_clause_stream_ ABSL_GUARDED_BY(mutex_);
WallTimer share_timer_ ABSL_GUARDED_BY(mutex_);
const bool always_synchronize_ = true;
const absl::Duration share_frequency_;
// Stats:
std::vector<int64_t> id_to_clauses_exported_;

View File

@@ -280,7 +280,7 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) {
<< "/" << nodes.size();
return;
}
if (nodes_.size() >= max_nodes_) {
if (nodes_.size() + 2 > max_nodes_) {
VLOG(2) << "Too many nodes to accept split";
return;
}