Files
ortools-clone/ortools/math_opt/math_opt_proto_utils.cc
2021-04-11 12:05:38 +02:00

81 lines
3.0 KiB
C++

// Copyright 2010-2021 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/math_opt_proto_utils.h"
#include <stdint.h>
#include <algorithm>
#include <functional>
#include "absl/base/log_severity.h"
#include "absl/container/flat_hash_set.h"
#include "ortools/base/integral_types.h"
#include "ortools/base/logging.h"
#include "ortools/math_opt/callback.pb.h"
#include "ortools/math_opt/sparse_containers.pb.h"
#include "ortools/math_opt/sparse_vector_view.h"
namespace operations_research {
namespace math_opt {
// Removes, in place, every entry of `sparse_vector` whose value compares equal
// to 0.0 (this matches both +0.0 and -0.0). Entries with any other value,
// including NaN, are kept in their original relative order. `ids` and `values`
// must have the same length; this is CHECK-enforced.
void RemoveSparseDoubleVectorZeros(SparseDoubleVectorProto& sparse_vector) {
  CHECK_EQ(sparse_vector.ids_size(), sparse_vector.values_size());
  // Stable in-place compaction. `num_kept` is both the count of entries retained
  // so far and the slot the next retained entry is written to. Since
  // num_kept <= read at all times, a write never overwrites an entry that has
  // not been read yet.
  int num_kept = 0;
  for (int read = 0; read < sparse_vector.ids_size(); ++read) {
    const double value = sparse_vector.values(read);
    // NaN compares unequal to everything, so `value == 0.0` is false for NaN
    // and such entries fall through and are kept.
    if (value == 0.0) continue;
    sparse_vector.set_ids(num_kept, sparse_vector.ids(read));
    sparse_vector.set_values(num_kept, value);
    ++num_kept;
  }
  // Drop the now-unused tail; `num_kept` is the new length of both fields.
  sparse_vector.mutable_ids()->Truncate(num_kept);
  sparse_vector.mutable_values()->Truncate(num_kept);
}
// Stores `filter` and, in debug builds only, validates it. When the filter
// restricts by ids, `filtered_ids` must be strictly increasing; a violation
// fails a CHECK.
SparseVectorFilterPredicate::SparseVectorFilterPredicate(
    const SparseVectorFilterProto& filter)
    : filter_(filter) {
  // Optimized builds skip validation entirely, as do filters that do not
  // filter by ids. The `||` short-circuits, so filter_by_ids() is only
  // consulted when DEBUG_MODE is set.
  if (!DEBUG_MODE || !filter_.filter_by_ids()) {
    return;
  }
  // Strictly increasing means ids(i) < ids(i+1) for every i. adjacent_find with
  // greater_equal returns the first adjacent pair where ids(i) >= ids(i+1).
  // Getting end() back therefore means no pair violates the ordering.
  const auto& ids = filter_.filtered_ids();
  CHECK(std::adjacent_find(ids.begin(), ids.end(),
                           std::greater_equal<int64_t>()) == ids.end())
      << "The input filter.filtered_ids must be strictly increasing.";
}
// Returns the set of distinct events listed in
// `callback_registration.request_registration`. Duplicates in the repeated
// field collapse to a single set element.
absl::flat_hash_set<CallbackEventProto> EventSet(
    const CallbackRegistrationProto& callback_registration) {
  // The repeated enum field is exposed in C++ as RepeatedField<int>, so a
  // range-for would produce plain ints. The generated indexed getter is used
  // instead so that each element comes back typed as CallbackEventProto.
  absl::flat_hash_set<CallbackEventProto> registered;
  const int count = callback_registration.request_registration_size();
  for (int index = 0; index < count; ++index) {
    registered.insert(callback_registration.request_registration(index));
  }
  return registered;
}
} // namespace math_opt
} // namespace operations_research