diff --git a/base/recordio.cc b/base/recordio.cc index 611403b971..1f39c00485 100644 --- a/base/recordio.cc +++ b/base/recordio.cc @@ -11,6 +11,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include +#include "base/logging.h" #include "base/recordio.h" namespace operations_research { @@ -22,9 +25,40 @@ bool RecordWriter::Close() { return file_->Close(); } +std::string RecordWriter::Compress(std::string const& s) const { + const unsigned long source_size = s.size(); + const char * source = s.c_str(); + + unsigned long dsize = source_size + (source_size * 0.1f) + 16; + char * const destination = new char[dsize]; + + const int result = compress((unsigned char *)destination, + &dsize, + (const unsigned char *)source, + source_size); + + if (result != Z_OK) { + LOG(FATAL) << "Compress error occured! Error code: " << result; + } + return std::string(destination, dsize); +} + RecordReader::RecordReader(File* const file) : file_(file) {} bool RecordReader::Close() { return file_->Close(); } + +void RecordReader::Uncompress(const char* const source, + unsigned long source_size, + char* const output_buffer, + unsigned long output_size) const { + const int result = uncompress((unsigned char *)output_buffer, + &output_size, + (const unsigned char *)source, + source_size); + if(result != Z_OK) { + LOG(FATAL) << "Uncompress error occured! Error code: " << result; + } +} } // namespace operations_research diff --git a/base/recordio.h b/base/recordio.h index e796bc1edd..feede17697 100644 --- a/base/recordio.h +++ b/base/recordio.h @@ -28,18 +28,27 @@ class RecordWriter { static const int kMagicNumber; explicit RecordWriter(File* const file); + template bool WriteProtocolMessage(const P& proto) { - std::string compressed_buffer; - proto.SerializeToString(&compressed_buffer); - const int size = compressed_buffer.length(); + std::string uncompressed_buffer; + proto.SerializeToString(&uncompressed_buffer); + const unsigned long uncompressed_size = uncompressed_buffer.size(); + const std::string compressed_buffer = Compress(uncompressed_buffer); + const unsigned long compressed_size = compressed_buffer.size(); if (file_->Write(&kMagicNumber, sizeof(kMagicNumber)) != sizeof(kMagicNumber)) { return false; } - if (file_->Write(&size, sizeof(size)) != sizeof(size)) { + if (file_->Write(&uncompressed_size, sizeof(uncompressed_size)) != + sizeof(uncompressed_size)) { return false; } - if (file_->Write(compressed_buffer.c_str(), size) != size) { + if (file_->Write(&compressed_size, sizeof(compressed_size)) != + sizeof(compressed_size)) { + return false; + } + if (file_->Write(compressed_buffer.c_str(), compressed_size) != + compressed_size) { return false; } return true; @@ -48,6 +57,7 @@ class RecordWriter { bool Close(); private: + std::string Compress(const std::string& input) const; File* const file_; }; @@ -55,8 +65,10 @@ class RecordWriter { class RecordReader { public: explicit RecordReader(File* const file); + template bool ReadProtocolMessage(P* const proto) { - int size = 0; + unsigned long usize = 0; + unsigned long csize = 0; int magic_number = 0; if (file_->Read(&magic_number, sizeof(magic_number)) != sizeof(magic_number)) { @@ -65,21 +77,32 @@ class RecordReader { if (magic_number != RecordWriter::kMagicNumber) { return false; } - if (file_->Read(&size, sizeof(size)) != sizeof(size)) { + if (file_->Read(&usize, sizeof(usize)) != sizeof(usize)) { return false; } - scoped_array buffer(new char[size + 1]); - if (file_->Read(buffer.get(), size) != size) { + if (file_->Read(&csize, sizeof(csize)) != sizeof(csize)) { return false; } - buffer[size] = 0; - proto->ParseFromArray(buffer.get(), size); + scoped_array compressed_buffer(new char[csize + 1]); + if (file_->Read(compressed_buffer.get(), csize) != csize) { + return false; + } + compressed_buffer[csize] = '\0'; + scoped_array buffer(new char[usize + 1]); + Uncompress(compressed_buffer.get(), usize, buffer.get(), usize); + proto->ParseFromArray(buffer.get(), usize); return true; } + // Closes the underlying file. bool Close(); private: + void Uncompress(const char* const source, + unsigned long source_size, + char* const output_buffer, + unsigned long output_size) const; + File* const file_; }; } // namespace operations_research