From e1ff18054da8bdda3bb82f05138f444afc34735c Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Wed, 17 Mar 2021 14:01:03 +0100 Subject: [PATCH] expand base libraries --- ortools/base/linked_hash_map.h | 637 +++++++++++++++++++++++++++++++++ ortools/base/logging.h | 2 + ortools/base/map_util.h | 30 ++ ortools/base/status_macros.h | 30 +- 4 files changed, 682 insertions(+), 17 deletions(-) create mode 100644 ortools/base/linked_hash_map.h diff --git a/ortools/base/linked_hash_map.h b/ortools/base/linked_hash_map.h new file mode 100644 index 0000000000..7d78a3394e --- /dev/null +++ b/ortools/base/linked_hash_map.h @@ -0,0 +1,637 @@ +// Copyright 2010-2018 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This is a simplistic insertion-ordered map. It behaves similarly to an STL +// map, but only implements a small subset of the map's methods. Internally, we +// just keep a map and a list going in parallel. +// +// This class provides no thread safety guarantees, beyond what you would +// normally see with std::list. +// +// Iterators point into the list and should be stable in the face of +// mutations, except for an iterator pointing to an element that was just +// deleted. +// +// This class supports heterogeneous lookups. +// +#ifndef OR_TOOLS_BASE_LINKED_HASH_MAP_H_ +#define OR_TOOLS_BASE_LINKED_HASH_MAP_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/internal/common.h" +#include "ortools/base/logging.h" + +namespace gtl { + +// This holds a list of pair items. This list is what gets +// traversed, and it's iterators from this list that we return from +// begin/end/find. +// +// We also keep a set for find. Since std::list is a +// doubly-linked list, the iterators should remain stable. +template ::hasher, + typename KeyEq = + typename absl::flat_hash_set::key_equal, + typename Alloc = std::allocator>> +class linked_hash_map { + using KeyArgImpl = absl::container_internal::KeyArg< + absl::container_internal::IsTransparent::value && + absl::container_internal::IsTransparent::value>; + // Alias used for heterogeneous lookup functions. + // `key_arg` evaluates to `K` when the functors are transparent and to + // `key_type` otherwise. It permits template argument deduction on `K` for the + // transparent case. + template + using key_arg = typename KeyArgImpl::template type; + + public: + using key_type = Key; + using mapped_type = Value; + using hasher = KeyHash; + using key_equal = KeyEq; + using value_type = std::pair; + using allocator_type = Alloc; + using difference_type = ptrdiff_t; + + private: + using ListType = std::list; + + template + class Wrapped { + template + static const K& ToKey(const K& k) { + return k; + } + static const key_type& ToKey(typename ListType::const_iterator it) { + return it->first; + } + static const key_type& ToKey(typename ListType::iterator it) { + return it->first; + } + + Fn fn_; + + friend linked_hash_map; + + public: + using is_transparent = void; + + Wrapped() = default; + explicit Wrapped(Fn fn) : fn_(std::move(fn)) {} + + template + auto operator()(Args&&... args) const + -> decltype(this->fn_(ToKey(args)...)) { + return fn_(ToKey(args)...); + } + }; + using SetType = + absl::flat_hash_set, + Wrapped, Alloc>; + + class NodeHandle { + public: + using key_type = linked_hash_map::key_type; + using mapped_type = linked_hash_map::mapped_type; + using allocator_type = linked_hash_map::allocator_type; + + constexpr NodeHandle() noexcept = default; + NodeHandle(NodeHandle&& nh) noexcept = default; + ~NodeHandle() = default; + NodeHandle& operator=(NodeHandle&& node) noexcept = default; + bool empty() const noexcept { return list_.empty(); } + explicit operator bool() const noexcept { return !empty(); } + allocator_type get_allocator() const { return list_.get_allocator(); } + const key_type& key() const { return list_.front().first; } + mapped_type& mapped() { return list_.front().second; } + void swap(NodeHandle& nh) noexcept { list_.swap(nh.list_); } + + private: + friend linked_hash_map; + + explicit NodeHandle(ListType list) : list_(std::move(list)) {} + ListType list_; + }; + + template + struct InsertReturnType { + Iterator position; + bool inserted; + NodeType node; + }; + + public: + using iterator = typename ListType::iterator; + using const_iterator = typename ListType::const_iterator; + using reverse_iterator = typename ListType::reverse_iterator; + using const_reverse_iterator = typename ListType::const_reverse_iterator; + using reference = typename ListType::reference; + using const_reference = typename ListType::const_reference; + using size_type = typename ListType::size_type; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = + typename std::allocator_traits::const_pointer; + using node_type = NodeHandle; + using insert_return_type = InsertReturnType; + + linked_hash_map() {} + + explicit linked_hash_map(size_t bucket_count, const hasher& hash = hasher(), + const key_equal& eq = key_equal(), + const allocator_type& alloc = allocator_type()) + : set_(bucket_count, Wrapped(hash), Wrapped(eq), + alloc), + list_(alloc) {} + + linked_hash_map(size_t bucket_count, const hasher& hash, + const allocator_type& alloc) + : linked_hash_map(bucket_count, hash, key_equal(), alloc) {} + + linked_hash_map(size_t bucket_count, const allocator_type& alloc) + : linked_hash_map(bucket_count, hasher(), key_equal(), alloc) {} + + explicit linked_hash_map(const allocator_type& alloc) + : linked_hash_map(0, hasher(), key_equal(), alloc) {} + + template + linked_hash_map(InputIt first, InputIt last, size_t bucket_count = 0, + const hasher& hash = hasher(), + const key_equal& eq = key_equal(), + const allocator_type& alloc = allocator_type()) + : linked_hash_map(bucket_count, hash, eq, alloc) { + insert(first, last); + } + + template + linked_hash_map(InputIt first, InputIt last, size_t bucket_count, + const hasher& hash, const allocator_type& alloc) + : linked_hash_map(first, last, bucket_count, hash, key_equal(), alloc) {} + + template + linked_hash_map(InputIt first, InputIt last, size_t bucket_count, + const allocator_type& alloc) + : linked_hash_map(first, last, bucket_count, hasher(), key_equal(), + alloc) {} + + template + linked_hash_map(InputIt first, InputIt last, const allocator_type& alloc) + : linked_hash_map(first, last, /*bucket_count=*/0, hasher(), key_equal(), + alloc) {} + + linked_hash_map(std::initializer_list init, + size_t bucket_count = 0, const hasher& hash = hasher(), + const key_equal& eq = key_equal(), + const allocator_type& alloc = allocator_type()) + : linked_hash_map(init.begin(), init.end(), bucket_count, hash, eq, + alloc) {} + + linked_hash_map(std::initializer_list init, size_t bucket_count, + const hasher& hash, const allocator_type& alloc) + : linked_hash_map(init, bucket_count, hash, key_equal(), alloc) {} + + linked_hash_map(std::initializer_list init, size_t bucket_count, + const allocator_type& alloc) + : linked_hash_map(init, bucket_count, hasher(), key_equal(), alloc) {} + + linked_hash_map(std::initializer_list init, + const allocator_type& alloc) + : linked_hash_map(init, /*bucket_count=*/0, hasher(), key_equal(), + alloc) {} + + linked_hash_map(const linked_hash_map& other) + : linked_hash_map(other.bucket_count(), other.hash_function(), + other.key_eq(), other.get_allocator()) { + CopyFrom(other); + } + + linked_hash_map(const linked_hash_map& other, const allocator_type& alloc) + : linked_hash_map(other.bucket_count(), other.hash_function(), + other.key_eq(), alloc) { + CopyFrom(other); + } + + linked_hash_map(linked_hash_map&& other) noexcept + : set_(std::move(other.set_)), list_(std::move(other.list_)) { + // Since the list and set must agree for other to end up "valid", + // explicitly clear them. + other.set_.clear(); + other.list_.clear(); + } + + linked_hash_map(linked_hash_map&& other, const allocator_type& alloc) + : linked_hash_map(0, other.hash_function(), other.key_eq(), alloc) { + if (get_allocator() == other.get_allocator()) { + *this = std::move(other); + } else { + CopyFrom(std::move(other)); + } + } + + linked_hash_map& operator=(const linked_hash_map& other) { + if (this == &other) return *this; + // Make a new set, with other's hash/eq/alloc. + set_ = SetType(other.bucket_count(), other.set_.hash_function(), + other.set_.key_eq(), other.get_allocator()); + // Copy the list, with other's allocator. + list_ = ListType(other.get_allocator()); + CopyFrom(other); + return *this; + } + + linked_hash_map& operator=(linked_hash_map&& other) noexcept { + // underlying containers will handle progagate_on_container_move details + set_ = std::move(other.set_); + list_ = std::move(other.list_); + other.set_.clear(); + other.list_.clear(); + return *this; + } + + linked_hash_map& operator=(std::initializer_list values) { + clear(); + insert(values.begin(), values.end()); + return *this; + } + + // Derive size_ from set_, as list::size might be O(N). + size_type size() const { return set_.size(); } + size_type max_size() const noexcept { return ~size_type{}; } + bool empty() const { return set_.empty(); } + + // Iteration is list-like, in insertion order. + // These are all forwarded. + iterator begin() { return list_.begin(); } + iterator end() { return list_.end(); } + const_iterator begin() const { return list_.begin(); } + const_iterator end() const { return list_.end(); } + const_iterator cbegin() const { return list_.cbegin(); } + const_iterator cend() const { return list_.cend(); } + reverse_iterator rbegin() { return list_.rbegin(); } + reverse_iterator rend() { return list_.rend(); } + const_reverse_iterator rbegin() const { return list_.rbegin(); } + const_reverse_iterator rend() const { return list_.rend(); } + const_reverse_iterator crbegin() const { return list_.crbegin(); } + const_reverse_iterator crend() const { return list_.crend(); } + reference front() { return list_.front(); } + reference back() { return list_.back(); } + const_reference front() const { return list_.front(); } + const_reference back() const { return list_.back(); } + + void pop_front() { erase(begin()); } + void pop_back() { erase(std::prev(end())); } + + ABSL_ATTRIBUTE_REINITIALIZES void clear() { + set_.clear(); + list_.clear(); + } + + void reserve(size_t n) { set_.reserve(n); } + size_t capacity() const { return set_.capacity(); } + size_t bucket_count() const { return set_.bucket_count(); } + float load_factor() const { return set_.load_factor(); } + + hasher hash_function() const { return set_.hash_function().fn_; } + key_equal key_eq() const { return set_.key_eq().fn_; } + allocator_type get_allocator() const { return list_.get_allocator(); } + + template + size_type erase(const key_arg& key) { + auto found = set_.find(key); + if (found == set_.end()) return 0; + auto list_it = *found; + // Erase set entry first since it refers to the list element. + set_.erase(found); + list_.erase(list_it); + return 1; + } + + iterator erase(const_iterator position) { + auto found = set_.find(position); + CHECK(*found == position) << "Inconsistent iterator for set and list, " + "or the iterator is invalid."; + set_.erase(found); + return list_.erase(position); + } + + iterator erase(iterator position) { + return erase(static_cast(position)); + } + + iterator erase(iterator first, iterator last) { + while (first != last) first = erase(first); + return first; + } + + iterator erase(const_iterator first, const_iterator last) { + while (first != last) first = erase(first); + if (first == end()) return end(); + return *set_.find(first); + } + + template + iterator find(const key_arg& key) { + auto found = set_.find(key); + if (found == set_.end()) return end(); + return *found; + } + + template + const_iterator find(const key_arg& key) const { + auto found = set_.find(key); + if (found == set_.end()) return end(); + return *found; + } + + template + size_type count(const key_arg& key) const { + return contains(key) ? 1 : 0; + } + template + bool contains(const key_arg& key) const { + return set_.contains(key); + } + + template + mapped_type& at(const key_arg& key) { + auto it = find(key); + if (ABSL_PREDICT_FALSE(it == end())) { + LOG(FATAL) << "linked_hash_map::at failed bounds check"; + } + return it->second; + } + + template + const mapped_type& at(const key_arg& key) const { + return const_cast(this)->at(key); + } + + template + std::pair equal_range(const key_arg& key) { + auto iter = set_.find(key); + if (iter == set_.end()) return {end(), end()}; + return {*iter, std::next(*iter)}; + } + + template + std::pair equal_range( + const key_arg& key) const { + auto iter = set_.find(key); + if (iter == set_.end()) return {end(), end()}; + return {*iter, std::next(*iter)}; + } + + template + mapped_type& operator[](const key_arg& key) { + return LazyEmplaceInternal(key).first->second; + } + + template + mapped_type& operator[](key_arg&& key) { + // K* = nullptr parameter above. + return LazyEmplaceInternal(std::forward(key)).first->second; + } + + std::pair insert(const value_type& v) { + return InsertInternal(v); + } + std::pair insert(value_type&& v) { // NOLINT(build/c++11) + return InsertInternal(std::move(v)); + } + + iterator insert(const_iterator, const value_type& v) { + return insert(v).first; + } + iterator insert(const_iterator, value_type&& v) { + return insert(std::move(v)).first; + } + + void insert(std::initializer_list ilist) { + insert(ilist.begin(), ilist.end()); + } + + template + void insert(InputIt first, InputIt last) { + for (; first != last; ++first) insert(*first); + } + + insert_return_type insert(node_type&& node) { + if (!node) return {end(), false, node_type()}; + auto itr = find(node.key()); + if (itr != end()) return {itr, false, std::move(node)}; + list_.splice(list_.end(), node.list_); + set_.insert(--list_.end()); + return {--list_.end(), true, node_type()}; + } + + iterator insert(const_iterator, node_type&& node) { + return insert(std::move(node)).first; + } + + // The last two template parameters ensure that both arguments are rvalues + // (lvalue arguments are handled by the overloads below). This is necessary + // for supporting bitfield arguments. + // + // union { int n : 1; }; + // linked_hash_map m; + // m.insert_or_assign(n, n); + template + std::pair insert_or_assign(key_arg&& k, V&& v) { + return InsertOrAssignInternal(std::forward(k), std::forward(v)); + } + + template + std::pair insert_or_assign(key_arg&& k, const V& v) { + return InsertOrAssignInternal(std::forward(k), v); + } + + template + std::pair insert_or_assign(const key_arg& k, V&& v) { + return InsertOrAssignInternal(k, std::forward(v)); + } + + template + std::pair insert_or_assign(const key_arg& k, const V& v) { + return InsertOrAssignInternal(k, v); + } + + template + iterator insert_or_assign(const_iterator, key_arg&& k, V&& v) { + return insert_or_assign(std::forward(k), std::forward(v)).first; + } + + template + iterator insert_or_assign(const_iterator, key_arg&& k, const V& v) { + return insert_or_assign(std::forward(k), v).first; + } + + template + iterator insert_or_assign(const_iterator, const key_arg& k, V&& v) { + return insert_or_assign(k, std::forward(v)).first; + } + + template + iterator insert_or_assign(const_iterator, const key_arg& k, const V& v) { + return insert_or_assign(k, v).first; + } + + template + std::pair emplace(Args&&... args) { + ListType node_donor; + auto list_iter = + node_donor.emplace(node_donor.end(), std::forward(args)...); + auto ins = set_.insert(list_iter); + if (!ins.second) return {*ins.first, false}; + list_.splice(list_.end(), node_donor, list_iter); + return {list_iter, true}; + } + + template + iterator try_emplace(const_iterator, key_arg&& k, Args&&... args) { + return try_emplace(std::forward(k), std::forward(args)...).first; + } + + template + iterator emplace_hint(const_iterator, Args&&... args) { + return emplace(std::forward(args)...).first; + } + + template + std::pair try_emplace(key_arg&& key, Args&&... args) { + return LazyEmplaceInternal(std::forward>(key), + std::forward(args)...); + } + + template + void merge(linked_hash_map& src) { + auto itr = src.list_.begin(); + while (itr != src.list_.end()) { + if (contains(itr->first)) { + ++itr; + } else { + insert(src.extract(itr++)); + } + } + } + + template + void merge(linked_hash_map&& src) { + merge(src); + } + + node_type extract(const_iterator position) { + set_.erase(position->first); + ListType extracted_node_list; + extracted_node_list.splice(extracted_node_list.end(), list_, position); + return node_type(std::move(extracted_node_list)); + } + + template , int> = 0> + node_type extract(const key_arg& key) { + auto it = find(key); + return it == end() ? node_type() : extract(const_iterator{it}); + } + + template + std::pair try_emplace(const key_arg& key, Args&&... args) { + return LazyEmplaceInternal(key, std::forward(args)...); + } + + void swap(linked_hash_map& other) { + using std::swap; + swap(set_, other.set_); + swap(list_, other.list_); + } + + friend bool operator==(const linked_hash_map& a, const linked_hash_map& b) { + if (a.size() != b.size()) return false; + const linked_hash_map* outer = &a; + const linked_hash_map* inner = &b; + if (outer->capacity() > inner->capacity()) std::swap(outer, inner); + for (const value_type& elem : *outer) { + auto it = inner->find(elem.first); + if (it == inner->end()) return false; + if (it->second != elem.second) return false; + } + + return true; + } + + friend bool operator!=(const linked_hash_map& a, const linked_hash_map& b) { + return !(a == b); + } + + void rehash(size_t n) { set_.rehash(n); } + + private: + template + void CopyFrom(Other&& other) { + for (auto& elem : other.list_) { + set_.insert(list_.insert(list_.end(), std::move(elem))); + } + DCHECK_EQ(set_.size(), list_.size()) << "Set and list are inconsistent."; + } + + template + std::pair InsertInternal(U&& pair) { // NOLINT(build/c++11) + auto iter = set_.find(pair.first); + if (iter != set_.end()) return {*iter, false}; + auto list_iter = list_.insert(list_.end(), std::forward(pair)); + auto inserted = set_.insert(list_iter); + DCHECK(inserted.second); + return {list_iter, true}; + } + + template + std::pair InsertOrAssignInternal(K&& k, V&& v) { + auto iter = set_.find(k); + if (iter != set_.end()) { + (*iter)->second = std::forward(v); + return {*iter, false}; + } + return LazyEmplaceInternal(std::forward(k), std::forward(v)); + } + + template + std::pair LazyEmplaceInternal(K&& key, Args&&... args) { + bool constructed = false; + auto set_iter = + set_.lazy_emplace(key, [this, &constructed, &key, &args...](auto ctor) { + auto list_iter = + list_.emplace(list_.end(), std::piecewise_construct, + std::forward_as_tuple(std::forward(key)), + std::forward_as_tuple(std::forward(args)...)); + constructed = true; + ctor(list_iter); + }); + return {*set_iter, constructed}; + } + + // The set component, used for speedy lookups. + SetType set_; + + // The list component, used for maintaining insertion order. + ListType list_; +}; + +} // namespace gtl + +#endif // OR_TOOLS_BASE_LINKED_HASH_MAP_H_ diff --git a/ortools/base/logging.h b/ortools/base/logging.h index f8880421a5..dc0fd34551 100644 --- a/ortools/base/logging.h +++ b/ortools/base/logging.h @@ -36,8 +36,10 @@ #include "ortools/base/vlog_is_on.h" #define QCHECK CHECK +#define QCHECK_EQ CHECK_EQ #define ABSL_DIE_IF_NULL CHECK_NOTNULL #define CHECK_OK(x) CHECK((x).ok()) +#define QCHECK_OK CHECK_OK // used by or-tools non C++ ports to bridge with the C++ layer. void FixFlagsAndEnvironmentForSwig(); diff --git a/ortools/base/map_util.h b/ortools/base/map_util.h index 86475638eb..d3df09a820 100644 --- a/ortools/base/map_util.h +++ b/ortools/base/map_util.h @@ -22,6 +22,9 @@ namespace gtl { // Perform a lookup in a std::map or std::unordered_map. // If the key is present in the map then the value associated with that // key is returned, otherwise the value passed as a default is returned. +// +// Prefer the two-argument form unless you need to specify a custom default +// value (i.e., one that is not equal to a value-initialized instance). template const typename Collection::value_type::second_type& FindWithDefault( const Collection& collection, @@ -34,6 +37,22 @@ const typename Collection::value_type::second_type& FindWithDefault( return it->second; } +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to a value-initialized object +// that is never destroyed. +template +const typename Collection::value_type::second_type& FindWithDefault( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + static const typename Collection::value_type::second_type* const + default_value = new typename Collection::value_type::second_type{}; + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + return *default_value; + } + return it->second; +} + // Perform a lookup in a std::map or std::unordered_map. // If the key is present a const pointer to the associated value is returned, // otherwise a NULL pointer is returned. @@ -148,6 +167,17 @@ void InsertOrDie(Collection* const collection, << "duplicate key: " << key; } +// Inserts a key into a map with the default value or dies. Returns a reference +// to the inserted element. +template +auto& InsertKeyOrDie(Collection* const collection, + const typename Collection::value_type::first_type& key) { + auto [it, did_insert] = collection->insert(typename Collection::value_type( + key, typename Collection::value_type::second_type())); + CHECK(did_insert) << "duplicate key " << key; + return it->second; +} + // Perform a lookup in std::map or std::unordered_map. // If the key is present and value is non-NULL then a copy of the value // associated with the key is made into *value. Returns whether key was present. diff --git a/ortools/base/status_macros.h b/ortools/base/status_macros.h index c3b45009e6..e6b4d6f13a 100644 --- a/ortools/base/status_macros.h +++ b/ortools/base/status_macros.h @@ -16,6 +16,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "ortools/base/status_builder.h" namespace absl { @@ -24,28 +25,23 @@ namespace absl { // // Example: // RETURN_IF_ERROR(DoThings(4)); -#define RETURN_IF_ERROR(expr) \ - do { \ - /* Using _status below to avoid capture problems if expr is "status". */ \ - const ::absl::Status _status = (expr); \ - if (!_status.ok()) return _status; \ - } while (0) +// RETURN_IF_ERROR(DoThings(5)) << "Additional error context"; +#define RETURN_IF_ERROR(expr) \ + switch (0) \ + case 0: \ + default: \ + if (const ::absl::Status status = (expr); status.ok()) { \ + } else /* NOLINT */ \ + return ::util::StatusBuilder(status) // Internal helper for concatenating macro values. #define STATUS_MACROS_CONCAT_NAME_INNER(x, y) x##y #define STATUS_MACROS_CONCAT_NAME(x, y) STATUS_MACROS_CONCAT_NAME_INNER(x, y) -template -::absl::Status DoAssignOrReturn(T& lhs, ::absl::StatusOr result) { // NOLINT - if (result.ok()) { - lhs = result.value(); - } - return result.status(); -} - -#define ASSIGN_OR_RETURN_IMPL(status, lhs, rexpr) \ - ::absl::Status status = DoAssignOrReturn(lhs, (rexpr)); \ - if (!status.ok()) return status; +#define ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + RETURN_IF_ERROR(statusor.status()); \ + lhs = *std::move(statusor) // Executes an expression that returns an absl::StatusOr, extracting its value // into the variable defined by lhs (or returning on error).