Commit d26d804c authored by topjohnwu's avatar topjohnwu

Migrate to generic stream implementation

parent 4f9a25ee
...@@ -21,19 +21,18 @@ using namespace std; ...@@ -21,19 +21,18 @@ using namespace std;
uint32_t dyn_img_hdr::j32 = 0; uint32_t dyn_img_hdr::j32 = 0;
uint64_t dyn_img_hdr::j64 = 0; uint64_t dyn_img_hdr::j64 = 0;
static int64_t one_step(unique_ptr<Compression> &&ptr, int fd, const void *in, size_t size) { static void decompress(format_t type, int fd, const void *in, size_t size) {
ptr->setOut(make_unique<FDOutStream>(fd)); unique_ptr<stream> ptr(get_decoder(type, open_stream<fd_stream>(fd)));
if (!ptr->write(in, size)) ptr->write(in, size);
return -1;
return ptr->finalize();
}
static int64_t decompress(format_t type, int fd, const void *in, size_t size) {
return one_step(unique_ptr<Compression>(get_decoder(type)), fd, in, size);
} }
static int64_t compress(format_t type, int fd, const void *in, size_t size) { static int64_t compress(format_t type, int fd, const void *in, size_t size) {
return one_step(unique_ptr<Compression>(get_encoder(type)), fd, in, size); auto prev = lseek(fd, 0, SEEK_CUR);
unique_ptr<stream> ptr(get_encoder(type, open_stream<fd_stream>(fd)));
ptr->write(in, size);
ptr->close();
auto now = lseek(fd, 0, SEEK_CUR);
return now - prev;
} }
static void dump(void *buf, size_t size, const char *filename) { static void dump(void *buf, size_t size, const char *filename) {
......
...@@ -6,6 +6,13 @@ ...@@ -6,6 +6,13 @@
#include <memory> #include <memory>
#include <functional> #include <functional>
#include <zlib.h>
#include <bzlib.h>
#include <lzma.h>
#include <lz4.h>
#include <lz4frame.h>
#include <lz4hc.h>
#include <logging.h> #include <logging.h>
#include <utils.h> #include <utils.h>
...@@ -14,260 +21,201 @@ ...@@ -14,260 +21,201 @@
using namespace std; using namespace std;
static bool read_file(FILE *fp, const function<void (void *, size_t)> &fn) { #define bwrite filter_stream::write
char buf[4096]; #define bclose filter_stream::close
size_t len;
while ((len = fread(buf, 1, sizeof(buf), fp)))
fn(buf, len);
return true;
}
void decompress(char *infile, const char *outfile) {
bool in_std = strcmp(infile, "-") == 0;
bool rm_in = false;
FILE *in_file = in_std ? stdin : xfopen(infile, "re");
int out_fd = -1;
unique_ptr<Compression> cmp;
read_file(in_file, [&](void *buf, size_t len) -> void {
if (out_fd < 0) {
format_t type = check_fmt(buf, len);
fprintf(stderr, "Detected format: [%s]\n", fmt2name[type]);
if (!COMPRESSED(type))
LOGE("Input file is not a supported compressed type!\n");
cmp.reset(get_decoder(type));
/* If user does not provide outfile, infile has to be either constexpr size_t CHUNK = 0x40000;
* <path>.[ext], or '-'. Outfile will be either <path> or '-'. constexpr size_t LZ4_UNCOMPRESSED = 0x800000;
* If the input does not have proper format, abort */ constexpr size_t LZ4_COMPRESSED = LZ4_COMPRESSBOUND(LZ4_UNCOMPRESSED);
char *ext = nullptr; class cpr_stream : public filter_stream {
if (outfile == nullptr) { public:
outfile = infile; explicit cpr_stream(FILE *fp) : filter_stream(fp) {}
if (!in_std) {
ext = strrchr(infile, '.');
if (ext == nullptr || strcmp(ext, fmt2ext[type]) != 0)
LOGE("Input file is not a supported type!\n");
// Strip out extension and remove input
*ext = '\0';
rm_in = true;
fprintf(stderr, "Decompressing to [%s]\n", outfile);
}
}
out_fd = strcmp(outfile, "-") == 0 ? int read(void *buf, size_t len) final {
STDOUT_FILENO : xopen(outfile, O_WRONLY | O_CREAT | O_TRUNC, 0644); return stream::read(buf, len);
cmp->setOut(make_unique<FDOutStream>(out_fd));
if (ext) *ext = '.';
} }
if (!cmp->write(buf, len))
LOGE("Decompression error!\n");
});
cmp->finalize();
fclose(in_file);
close(out_fd);
if (rm_in) int close() final {
unlink(infile); finish();
} return bclose();
void compress(const char *method, const char *infile, const char *outfile) {
auto it = name2fmt.find(method);
if (it == name2fmt.end())
LOGE("Unsupported compression method: [%s]\n", method);
unique_ptr<Compression> cmp(get_encoder(it->second));
bool in_std = strcmp(infile, "-") == 0;
bool rm_in = false;
FILE *in_file = in_std ? stdin : xfopen(infile, "re");
int out_fd;
if (outfile == nullptr) {
if (in_std) {
out_fd = STDOUT_FILENO;
} else {
/* If user does not provide outfile and infile is not
* STDIN, output to <infile>.[ext] */
char *tmp = new char[strlen(infile) + 5];
sprintf(tmp, "%s%s", infile, fmt2ext[it->second]);
out_fd = xopen(tmp, O_WRONLY | O_CREAT | O_TRUNC, 0644);
fprintf(stderr, "Compressing to [%s]\n", tmp);
delete[] tmp;
rm_in = true;
} }
} else {
out_fd = strcmp(outfile, "-") == 0 ?
STDOUT_FILENO : xopen(outfile, O_WRONLY | O_CREAT | O_TRUNC, 0644);
}
cmp->setOut(make_unique<FDOutStream>(out_fd));
read_file(in_file, [&](void *buf, size_t len) -> void { protected:
if (!cmp->write(buf, len)) // If finish is overridden, destroy should be called in the destructor
LOGE("Compression error!\n"); virtual void finish() {}
}); void destroy() { if (fp) finish(); }
};
cmp->finalize(); class gz_strm : public cpr_stream {
fclose(in_file); public:
close(out_fd); ~gz_strm() override { destroy(); }
if (rm_in) int write(const void *buf, size_t len) override {
unlink(infile); return len ? write(buf, len, Z_NO_FLUSH) : 0;
}
Compression *get_encoder(format_t type) {
switch (type) {
case XZ:
return new XZEncoder();
case LZMA:
return new LZMAEncoder();
case BZIP2:
return new BZEncoder();
case LZ4:
return new LZ4FEncoder();
case LZ4_LEGACY:
return new LZ4Encoder();
case GZIP:
default:
return new GZEncoder();
} }
}
Compression *get_decoder(format_t type) { protected:
switch (type) { enum mode_t {
case XZ: DECODE,
case LZMA: ENCODE
return new LZMADecoder(); } mode;
case BZIP2:
return new BZDecoder();
case LZ4:
return new LZ4FDecoder();
case LZ4_LEGACY:
return new LZ4Decoder();
case GZIP:
default:
return new GZDecoder();
}
}
GZStream::GZStream(int mode) : mode(mode), strm({}) { gz_strm(mode_t mode, FILE *fp) : cpr_stream(fp), mode(mode) {
switch(mode) { switch(mode) {
case 0: case DECODE:
inflateInit2(&strm, 15 | 16); inflateInit2(&strm, 15 | 16);
break; break;
case 1: case ENCODE:
deflateInit2(&strm, 9, Z_DEFLATED, 15 | 16, 8, Z_DEFAULT_STRATEGY); deflateInit2(&strm, 9, Z_DEFLATED, 15 | 16, 8, Z_DEFAULT_STRATEGY);
break; break;
} }
} }
bool GZStream::write(const void *in, size_t size) {
return size ? write(in, size, Z_NO_FLUSH) : true;
}
uint64_t GZStream::finalize() { void finish() override {
write(nullptr, 0, Z_FINISH); write(nullptr, 0, Z_FINISH);
uint64_t total = strm.total_out;
switch(mode) { switch(mode) {
case 0: case DECODE:
inflateEnd(&strm); inflateEnd(&strm);
break; break;
case 1: case ENCODE:
deflateEnd(&strm); deflateEnd(&strm);
break; break;
} }
return total; }
}
bool GZStream::write(const void *in, size_t size, int flush) { private:
int ret; z_stream strm;
strm.next_in = (Bytef *) in; uint8_t outbuf[CHUNK];
strm.avail_in = size;
int write(const void *buf, size_t len, int flush) {
strm.next_in = (Bytef *) buf;
strm.avail_in = len;
do { do {
int code;
strm.next_out = outbuf; strm.next_out = outbuf;
strm.avail_out = sizeof(outbuf); strm.avail_out = sizeof(outbuf);
switch(mode) { switch(mode) {
case 0: case DECODE:
ret = inflate(&strm, flush); code = inflate(&strm, flush);
break; break;
case 1: case ENCODE:
ret = deflate(&strm, flush); code = deflate(&strm, flush);
break; break;
} }
if (ret == Z_STREAM_ERROR) { if (code == Z_STREAM_ERROR) {
LOGW("Gzip %s failed (%d)\n", mode ? "encode" : "decode", ret); LOGW("gzip %s failed (%d)\n", mode ? "encode" : "decode", code);
return false; return -1;
} }
FilterOutStream::write(outbuf, sizeof(outbuf) - strm.avail_out); bwrite(outbuf, sizeof(outbuf) - strm.avail_out);
} while (strm.avail_out == 0); } while (strm.avail_out == 0);
return true; return len;
} }
};
class gz_decoder : public gz_strm {
public:
explicit gz_decoder(FILE *fp) : gz_strm(DECODE, fp) {};
};
BZStream::BZStream(int mode) : mode(mode), strm({}) { class gz_encoder : public gz_strm {
public:
explicit gz_encoder(FILE *fp) : gz_strm(ENCODE, fp) {};
};
class bz_strm : public cpr_stream {
public:
~bz_strm() override { destroy(); }
int write(const void *buf, size_t len) override {
return len ? write(buf, len, BZ_RUN) : 0;
}
protected:
enum mode_t {
DECODE,
ENCODE
} mode;
bz_strm(mode_t mode, FILE *fp) : cpr_stream(fp), mode(mode) {
switch(mode) { switch(mode) {
case 0: case DECODE:
BZ2_bzDecompressInit(&strm, 0, 0); BZ2_bzDecompressInit(&strm, 0, 0);
break; break;
case 1: case ENCODE:
BZ2_bzCompressInit(&strm, 9, 0, 0); BZ2_bzCompressInit(&strm, 9, 0, 0);
break; break;
} }
} }
bool BZStream::write(const void *in, size_t size) {
return size ? write(in, size, BZ_RUN) : true;
}
uint64_t BZStream::finalize() { void finish() override {
if (mode)
write(nullptr, 0, BZ_FINISH);
uint64_t total = ((uint64_t) strm.total_out_hi32 << 32) + strm.total_out_lo32;
switch(mode) { switch(mode) {
case 0: case DECODE:
BZ2_bzDecompressEnd(&strm); BZ2_bzDecompressEnd(&strm);
break; break;
case 1: case ENCODE:
write(nullptr, 0, BZ_FINISH);
BZ2_bzCompressEnd(&strm); BZ2_bzCompressEnd(&strm);
break; break;
} }
return total; }
}
bool BZStream::write(const void *in, size_t size, int flush) { private:
int ret; bz_stream strm;
strm.next_in = (char *) in; char outbuf[CHUNK];
strm.avail_in = size;
int write(const void *buf, size_t len, int flush) {
strm.next_in = (char *) buf;
strm.avail_in = len;
do { do {
int code;
strm.avail_out = sizeof(outbuf); strm.avail_out = sizeof(outbuf);
strm.next_out = outbuf; strm.next_out = outbuf;
switch(mode) { switch(mode) {
case 0: case DECODE:
ret = BZ2_bzDecompress(&strm); code = BZ2_bzDecompress(&strm);
break; break;
case 1: case ENCODE:
ret = BZ2_bzCompress(&strm, flush); code = BZ2_bzCompress(&strm, flush);
break; break;
} }
if (ret < 0) { if (code < 0) {
LOGW("Bzip2 %s failed (%d)\n", mode ? "encode" : "decode", ret); LOGW("bzip2 %s failed (%d)\n", mode ? "encode" : "decode", code);
return false; return -1;
} }
FilterOutStream::write(outbuf, sizeof(outbuf) - strm.avail_out); bwrite(outbuf, sizeof(outbuf) - strm.avail_out);
} while (strm.avail_out == 0); } while (strm.avail_out == 0);
return true; return len;
} }
};
class bz_decoder : public bz_strm {
public:
explicit bz_decoder(FILE *fp) : bz_strm(DECODE, fp) {};
};
class bz_encoder : public bz_strm {
public:
explicit bz_encoder(FILE *fp) : bz_strm(ENCODE, fp) {};
};
class lzma_strm : public cpr_stream {
public:
~lzma_strm() override { destroy(); }
LZMAStream::LZMAStream(int mode) : mode(mode), strm(LZMA_STREAM_INIT) { int write(const void *buf, size_t len) override {
return len ? write(buf, len, LZMA_RUN) : 0;
}
protected:
enum mode_t {
DECODE,
ENCODE_XZ,
ENCODE_LZMA
} mode;
lzma_strm(mode_t mode, FILE *fp) : cpr_stream(fp), mode(mode), strm(LZMA_STREAM_INIT) {
lzma_options_lzma opt; lzma_options_lzma opt;
int ret;
// Initialize preset // Initialize preset
lzma_lzma_preset(&opt, 9); lzma_lzma_preset(&opt, 9);
...@@ -276,83 +224,100 @@ LZMAStream::LZMAStream(int mode) : mode(mode), strm(LZMA_STREAM_INIT) { ...@@ -276,83 +224,100 @@ LZMAStream::LZMAStream(int mode) : mode(mode), strm(LZMA_STREAM_INIT) {
{ .id = LZMA_VLI_UNKNOWN, .options = nullptr }, { .id = LZMA_VLI_UNKNOWN, .options = nullptr },
}; };
lzma_ret ret;
switch(mode) { switch(mode) {
case 0: case DECODE:
ret = lzma_auto_decoder(&strm, UINT64_MAX, 0); ret = lzma_auto_decoder(&strm, UINT64_MAX, 0);
break; break;
case 1: case ENCODE_XZ:
ret = lzma_stream_encoder(&strm, filters, LZMA_CHECK_CRC32); ret = lzma_stream_encoder(&strm, filters, LZMA_CHECK_CRC32);
break; break;
case 2: case ENCODE_LZMA:
ret = lzma_alone_encoder(&strm, &opt); ret = lzma_alone_encoder(&strm, &opt);
break; break;
} }
} }
bool LZMAStream::write(const void *in, size_t size) {
return size ? write(in, size, LZMA_RUN) : true;
}
uint64_t LZMAStream::finalize() { void finish() override {
write(nullptr, 0, LZMA_FINISH); write(nullptr, 0, LZMA_FINISH);
uint64_t total = strm.total_out;
lzma_end(&strm); lzma_end(&strm);
return total; }
}
bool LZMAStream::write(const void *in, size_t size, lzma_action flush) { private:
int ret; lzma_stream strm;
strm.next_in = (uint8_t *) in; uint8_t outbuf[CHUNK];
strm.avail_in = size;
int write(const void *buf, size_t len, lzma_action flush) {
strm.next_in = (uint8_t *) buf;
strm.avail_in = len;
do { do {
strm.avail_out = sizeof(outbuf); strm.avail_out = sizeof(outbuf);
strm.next_out = outbuf; strm.next_out = outbuf;
ret = lzma_code(&strm, flush); int code = lzma_code(&strm, flush);
if (ret != LZMA_OK && ret != LZMA_STREAM_END) { if (code != LZMA_OK && code != LZMA_STREAM_END) {
LOGW("LZMA %s failed (%d)\n", mode ? "encode" : "decode", ret); LOGW("LZMA %s failed (%d)\n", mode ? "encode" : "decode", code);
return false; return -1;
} }
FilterOutStream::write(outbuf, sizeof(outbuf) - strm.avail_out); bwrite(outbuf, sizeof(outbuf) - strm.avail_out);
} while (strm.avail_out == 0); } while (strm.avail_out == 0);
return true; return len;
} }
};
LZ4FDecoder::LZ4FDecoder() : outbuf(nullptr), total(0) {
class lzma_decoder : public lzma_strm {
public:
lzma_decoder(FILE *fp) : lzma_strm(DECODE, fp) {}
};
class xz_encoder : public lzma_strm {
public:
xz_encoder(FILE *fp) : lzma_strm(ENCODE_XZ, fp) {}
};
class lzma_encoder : public lzma_strm {
public:
lzma_encoder(FILE *fp) : lzma_strm(ENCODE_LZMA, fp) {}
};
class LZ4F_decoder : public cpr_stream {
public:
explicit LZ4F_decoder(FILE *fp) : cpr_stream(fp), outbuf(nullptr) {
LZ4F_createDecompressionContext(&ctx, LZ4F_VERSION); LZ4F_createDecompressionContext(&ctx, LZ4F_VERSION);
} }
LZ4FDecoder::~LZ4FDecoder() { ~LZ4F_decoder() override {
LZ4F_freeDecompressionContext(ctx); LZ4F_freeDecompressionContext(ctx);
delete[] outbuf; delete[] outbuf;
} }
bool LZ4FDecoder::write(const void *in, size_t size) { int write(const void *buf, size_t len) override {
auto inbuf = (const uint8_t *) in; auto ret = len;
auto inbuf = reinterpret_cast<const uint8_t *>(buf);
if (!outbuf) if (!outbuf)
read_header(inbuf, size); read_header(inbuf, len);
size_t read, write; size_t read, write;
LZ4F_errorCode_t ret; LZ4F_errorCode_t code;
do { do {
read = size; read = len;
write = outCapacity; write = outCapacity;
ret = LZ4F_decompress(ctx, outbuf, &write, inbuf, &read, nullptr); code = LZ4F_decompress(ctx, outbuf, &write, inbuf, &read, nullptr);
if (LZ4F_isError(ret)) { if (LZ4F_isError(code)) {
LOGW("LZ4 decode error: %s\n", LZ4F_getErrorName(ret)); LOGW("LZ4F decode error: %s\n", LZ4F_getErrorName(code));
return false; return -1;
} }
size -= read; len -= read;
inbuf += read; inbuf += read;
total += write; bwrite(outbuf, write);
FilterOutStream::write(outbuf, write); } while (len != 0 || write != 0);
} while (size != 0 || write != 0); return ret;
return true; }
}
uint64_t LZ4FDecoder::finalize() { private:
return total; LZ4F_decompressionContext_t ctx;
} uint8_t *outbuf;
size_t outCapacity;
void LZ4FDecoder::read_header(const uint8_t *&in, size_t &size) { void read_header(const uint8_t *&in, size_t &size) {
size_t read = size; size_t read = size;
LZ4F_frameInfo_t info; LZ4F_frameInfo_t info;
LZ4F_getFrameInfo(ctx, &info, in, &read); LZ4F_getFrameInfo(ctx, &info, in, &read);
...@@ -366,47 +331,57 @@ void LZ4FDecoder::read_header(const uint8_t *&in, size_t &size) { ...@@ -366,47 +331,57 @@ void LZ4FDecoder::read_header(const uint8_t *&in, size_t &size) {
outbuf = new uint8_t[outCapacity]; outbuf = new uint8_t[outCapacity];
in += read; in += read;
size -= read; size -= read;
} }
};
LZ4FEncoder::LZ4FEncoder() : outbuf(nullptr), outCapacity(0), total(0) { class LZ4F_encoder : public cpr_stream {
public:
explicit LZ4F_encoder(FILE *fp) : cpr_stream(fp), outbuf(nullptr), outCapacity(0) {
LZ4F_createCompressionContext(&ctx, LZ4F_VERSION); LZ4F_createCompressionContext(&ctx, LZ4F_VERSION);
} }
LZ4FEncoder::~LZ4FEncoder() { ~LZ4F_encoder() override {
destroy();
LZ4F_freeCompressionContext(ctx); LZ4F_freeCompressionContext(ctx);
delete[] outbuf; delete[] outbuf;
} }
bool LZ4FEncoder::write(const void *in, size_t size) { int write(const void *buf, size_t len) override {
auto ret = len;
if (!outbuf) if (!outbuf)
write_header(); write_header();
if (size == 0) if (len == 0)
return true; return 0;
auto inbuf = (const uint8_t *) in; auto inbuf = reinterpret_cast<const uint8_t *>(buf);
size_t read, write; size_t read, write;
do { do {
read = size > BLOCK_SZ ? BLOCK_SZ : size; read = len > BLOCK_SZ ? BLOCK_SZ : len;
write = LZ4F_compressUpdate(ctx, outbuf, outCapacity, inbuf, read, nullptr); write = LZ4F_compressUpdate(ctx, outbuf, outCapacity, inbuf, read, nullptr);
if (LZ4F_isError(write)) { if (LZ4F_isError(write)) {
LOGW("LZ4 encode error: %s\n", LZ4F_getErrorName(write)); LOGW("LZ4F encode error: %s\n", LZ4F_getErrorName(write));
return false; return -1;
} }
size -= read; len -= read;
inbuf += read; inbuf += read;
total += write; bwrite(outbuf, write);
FilterOutStream::write(outbuf, write); } while (len != 0);
} while (size != 0); return ret;
return true; }
}
uint64_t LZ4FEncoder::finalize() { protected:
size_t write = LZ4F_compressEnd(ctx, outbuf, outCapacity, nullptr); void finish() override {
total += write; size_t len = LZ4F_compressEnd(ctx, outbuf, outCapacity, nullptr);
FilterOutStream::write(outbuf, write); bwrite(outbuf, len);
return total; }
}
private:
LZ4F_compressionContext_t ctx;
uint8_t *outbuf;
size_t outCapacity;
static constexpr size_t BLOCK_SZ = 1 << 22;
void LZ4FEncoder::write_header() { void write_header() {
LZ4F_preferences_t prefs { LZ4F_preferences_t prefs {
.autoFlush = 1, .autoFlush = 1,
.compressionLevel = 9, .compressionLevel = 9,
...@@ -420,20 +395,24 @@ void LZ4FEncoder::write_header() { ...@@ -420,20 +395,24 @@ void LZ4FEncoder::write_header() {
outCapacity = LZ4F_compressBound(BLOCK_SZ, &prefs); outCapacity = LZ4F_compressBound(BLOCK_SZ, &prefs);
outbuf = new uint8_t[outCapacity]; outbuf = new uint8_t[outCapacity];
size_t write = LZ4F_compressBegin(ctx, outbuf, outCapacity, &prefs); size_t write = LZ4F_compressBegin(ctx, outbuf, outCapacity, &prefs);
total += write; bwrite(outbuf, write);
FilterOutStream::write(outbuf, write); }
} };
LZ4Decoder::LZ4Decoder() : outbuf(new char[LZ4_UNCOMPRESSED]), buf(new char[LZ4_COMPRESSED]), class LZ4_decoder : public cpr_stream {
init(false), block_sz(0), buf_off(0), total(0) {} public:
explicit LZ4_decoder(FILE *fp)
: cpr_stream(fp), out_buf(new char[LZ4_UNCOMPRESSED]), buffer(new char[LZ4_COMPRESSED]),
init(false), block_sz(0), buf_off(0) {}
LZ4Decoder::~LZ4Decoder() { ~LZ4_decoder() override {
delete[] outbuf; delete[] out_buf;
delete[] buf; delete[] buffer;
} }
bool LZ4Decoder::write(const void *in, size_t size) { int write(const void *in, size_t size) override {
const char *inbuf = (const char *) in; auto ret = size;
auto inbuf = static_cast<const char *>(in);
if (!init) { if (!init) {
// Skip magic // Skip magic
inbuf += 4; inbuf += 4;
...@@ -449,50 +428,57 @@ bool LZ4Decoder::write(const void *in, size_t size) { ...@@ -449,50 +428,57 @@ bool LZ4Decoder::write(const void *in, size_t size) {
size -= sizeof(unsigned); size -= sizeof(unsigned);
} else if (buf_off + size >= block_sz) { } else if (buf_off + size >= block_sz) {
consumed = block_sz - buf_off; consumed = block_sz - buf_off;
memcpy(buf + buf_off, inbuf, consumed); memcpy(buffer + buf_off, inbuf, consumed);
inbuf += consumed; inbuf += consumed;
size -= consumed; size -= consumed;
write = LZ4_decompress_safe(buf, outbuf, block_sz, LZ4_UNCOMPRESSED); write = LZ4_decompress_safe(buffer, out_buf, block_sz, LZ4_UNCOMPRESSED);
if (write < 0) { if (write < 0) {
LOGW("LZ4HC decompression failure (%d)\n", write); LOGW("LZ4HC decompression failure (%d)\n", write);
return false; return -1;
} }
FilterOutStream::write(outbuf, write); bwrite(out_buf, write);
total += write;
// Reset // Reset
buf_off = 0; buf_off = 0;
block_sz = 0; block_sz = 0;
} else { } else {
// Copy to internal buffer // Copy to internal buffer
memcpy(buf + buf_off, inbuf, size); memcpy(buffer + buf_off, inbuf, size);
buf_off += size; buf_off += size;
size = 0; size = 0;
} }
} while (size != 0); } while (size != 0);
return true; return ret;
} }
uint64_t LZ4Decoder::finalize() {
return total;
}
LZ4Encoder::LZ4Encoder() : outbuf(new char[LZ4_COMPRESSED]), buf(new char[LZ4_UNCOMPRESSED]),
init(false), buf_off(0), out_total(0), in_total(0) {}
LZ4Encoder::~LZ4Encoder() { private:
char *out_buf;
char *buffer;
bool init;
unsigned block_sz;
int buf_off;
};
class LZ4_encoder : public cpr_stream {
public:
explicit LZ4_encoder(FILE *fp)
: cpr_stream(fp), outbuf(new char[LZ4_COMPRESSED]), buf(new char[LZ4_UNCOMPRESSED]),
init(false), buf_off(0), in_total(0) {}
~LZ4_encoder() override {
destroy();
delete[] outbuf; delete[] outbuf;
delete[] buf; delete[] buf;
} }
bool LZ4Encoder::write(const void *in, size_t size) { int write(const void *in, size_t size) override {
if (!init) { if (!init) {
FilterOutStream::write("\x02\x21\x4c\x18", 4); bwrite("\x02\x21\x4c\x18", 4);
init = true; init = true;
} }
if (size == 0) if (size == 0)
return true; return 0;
in_total += size; in_total += size;
const char *inbuf = (const char *) in; const char *inbuf = (const char *) in;
size_t consumed; size_t consumed;
...@@ -509,9 +495,8 @@ bool LZ4Encoder::write(const void *in, size_t size) { ...@@ -509,9 +495,8 @@ bool LZ4Encoder::write(const void *in, size_t size) {
LOGW("LZ4HC compression failure\n"); LOGW("LZ4HC compression failure\n");
return false; return false;
} }
FilterOutStream::write(&write, sizeof(write)); bwrite(&write, sizeof(write));
FilterOutStream::write(outbuf, write); bwrite(outbuf, write);
out_total += write + sizeof(write);
// Reset buffer // Reset buffer
buf_off = 0; buf_off = 0;
...@@ -523,15 +508,152 @@ bool LZ4Encoder::write(const void *in, size_t size) { ...@@ -523,15 +508,152 @@ bool LZ4Encoder::write(const void *in, size_t size) {
} }
} while (size != 0); } while (size != 0);
return true; return true;
} }
uint64_t LZ4Encoder::finalize() { protected:
void finish() override {
if (buf_off) { if (buf_off) {
int write = LZ4_compress_HC(buf, outbuf, buf_off, LZ4_COMPRESSED, 9); int write = LZ4_compress_HC(buf, outbuf, buf_off, LZ4_COMPRESSED, 9);
FilterOutStream::write(&write, sizeof(write)); bwrite(&write, sizeof(write));
FilterOutStream::write(outbuf, write); bwrite(outbuf, write);
out_total += write + sizeof(write); }
bwrite(&in_total, sizeof(in_total));
}
private:
char *outbuf;
char *buf;
bool init;
int buf_off;
unsigned in_total;
};
filter_stream *get_encoder(format_t type, FILE *fp) {
switch (type) {
case XZ:
return new xz_encoder(fp);
case LZMA:
return new lzma_encoder(fp);
case BZIP2:
return new bz_encoder(fp);
case LZ4:
return new LZ4F_encoder(fp);
case LZ4_LEGACY:
return new LZ4_encoder(fp);
case GZIP:
default:
return new gz_encoder(fp);
}
}
filter_stream *get_decoder(format_t type, FILE *fp) {
switch (type) {
case XZ:
case LZMA:
return new lzma_decoder(fp);
case BZIP2:
return new bz_decoder(fp);
case LZ4:
return new LZ4F_decoder(fp);
case LZ4_LEGACY:
return new LZ4_decoder(fp);
case GZIP:
default:
return new gz_decoder(fp);
}
}
void decompress(char *infile, const char *outfile) {
bool in_std = infile == "-"sv;
bool rm_in = false;
FILE *in_fp = in_std ? stdin : xfopen(infile, "re");
unique_ptr<stream> strm;
char buf[4096];
size_t len;
while ((len = fread(buf, 1, sizeof(buf), in_fp))) {
if (!strm) {
format_t type = check_fmt(buf, len);
if (!COMPRESSED(type))
LOGE("Input file is not a supported compressed type!\n");
fprintf(stderr, "Detected format: [%s]\n", fmt2name[type]);
/* If user does not provide outfile, infile has to be either
* <path>.[ext], or '-'. Outfile will be either <path> or '-'.
* If the input does not have proper format, abort */
char *ext = nullptr;
if (outfile == nullptr) {
outfile = infile;
if (!in_std) {
ext = strrchr(infile, '.');
if (ext == nullptr || strcmp(ext, fmt2ext[type]) != 0)
LOGE("Input file is not a supported type!\n");
// Strip out extension and remove input
*ext = '\0';
rm_in = true;
fprintf(stderr, "Decompressing to [%s]\n", outfile);
}
}
FILE *out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we");
strm.reset(get_decoder(type, out_fp));
if (ext) *ext = '.';
} }
FilterOutStream::write(&in_total, sizeof(in_total)); if (strm->write(buf, len) < 0)
return out_total + sizeof(in_total); LOGE("Decompression error!\n");
}
strm->close();
fclose(in_fp);
if (rm_in)
unlink(infile);
}
void compress(const char *method, const char *infile, const char *outfile) {
auto it = name2fmt.find(method);
if (it == name2fmt.end())
LOGE("Unknown compression method: [%s]\n", method);
bool in_std = infile == "-"sv;
bool rm_in = false;
FILE *in_fp = in_std ? stdin : xfopen(infile, "re");
FILE *out_fp;
if (outfile == nullptr) {
if (in_std) {
out_fp = stdout;
} else {
/* If user does not provide outfile and infile is not
* STDIN, output to <infile>.[ext] */
string tmp(infile);
tmp += fmt2ext[it->second];
out_fp = xfopen(tmp.data(), "we");
fprintf(stderr, "Compressing to [%s]\n", tmp.data());
rm_in = true;
}
} else {
out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we");
}
unique_ptr<stream> strm(get_encoder(it->second, out_fp));
char buf[4096];
size_t len;
while ((len = fread(buf, 1, sizeof(buf), in_fp))) {
if (strm->write(buf, len) < 0)
LOGE("Compression error!\n");
};
strm->close();
fclose(in_fp);
if (rm_in)
unlink(infile);
} }
#pragma once #pragma once
#include <zlib.h>
#include <bzlib.h>
#include <lzma.h>
#include <lz4.h>
#include <lz4frame.h>
#include <lz4hc.h>
#include <stream.h> #include <stream.h>
#include "format.h" #include "format.h"
#define CHUNK 0x40000 filter_stream *get_encoder(format_t type, FILE *fp = nullptr);
filter_stream *get_decoder(format_t type, FILE *fp = nullptr);
class Compression : public FilterOutStream {
public:
virtual uint64_t finalize() = 0;
};
class GZStream : public Compression {
public:
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
protected:
explicit GZStream(int mode);
private:
int mode;
z_stream strm;
uint8_t outbuf[CHUNK];
bool write(const void *in, size_t size, int flush);
};
class GZDecoder : public GZStream {
public:
GZDecoder() : GZStream(0) {};
};
class GZEncoder : public GZStream {
public:
GZEncoder() : GZStream(1) {};
};
class BZStream : public Compression {
public:
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
protected:
explicit BZStream(int mode);
private:
int mode;
bz_stream strm;
char outbuf[CHUNK];
bool write(const void *in, size_t size, int flush);
};
class BZDecoder : public BZStream {
public:
BZDecoder() : BZStream(0) {};
};
class BZEncoder : public BZStream {
public:
BZEncoder() : BZStream(1) {};
};
class LZMAStream : public Compression {
public:
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
protected:
explicit LZMAStream(int mode);
private:
int mode;
lzma_stream strm;
uint8_t outbuf[CHUNK];
bool write(const void *in, size_t size, lzma_action flush);
};
class LZMADecoder : public LZMAStream {
public:
LZMADecoder() : LZMAStream(0) {}
};
class XZEncoder : public LZMAStream {
public:
XZEncoder() : LZMAStream(1) {}
};
class LZMAEncoder : public LZMAStream {
public:
LZMAEncoder() : LZMAStream(2) {}
};
class LZ4FDecoder : public Compression {
public:
LZ4FDecoder();
~LZ4FDecoder() override;
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
private:
LZ4F_decompressionContext_t ctx;
uint8_t *outbuf;
size_t outCapacity;
uint64_t total;
void read_header(const uint8_t *&in, size_t &size);
};
class LZ4FEncoder : public Compression {
public:
LZ4FEncoder();
~LZ4FEncoder() override;
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
private:
static constexpr size_t BLOCK_SZ = 1 << 22;
LZ4F_compressionContext_t ctx;
uint8_t *outbuf;
size_t outCapacity;
uint64_t total;
void write_header();
};
#define LZ4_UNCOMPRESSED 0x800000
#define LZ4_COMPRESSED LZ4_COMPRESSBOUND(LZ4_UNCOMPRESSED)
class LZ4Decoder : public Compression {
public:
LZ4Decoder();
~LZ4Decoder() override;
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
private:
char *outbuf;
char *buf;
bool init;
unsigned block_sz;
int buf_off;
uint64_t total;
};
class LZ4Encoder : public Compression {
public:
LZ4Encoder();
~LZ4Encoder() override;
bool write(const void *in, size_t size) override;
uint64_t finalize() override;
private:
char *outbuf;
char *buf;
bool init;
int buf_off;
uint64_t out_total;
unsigned in_total;
};
Compression *get_encoder(format_t type);
Compression *get_decoder(format_t type);
void compress(const char *method, const char *infile, const char *outfile); void compress(const char *method, const char *infile, const char *outfile);
void decompress(char *infile, const char *outfile); void decompress(char *infile, const char *outfile);
...@@ -241,14 +241,17 @@ void magisk_cpio::compress() { ...@@ -241,14 +241,17 @@ void magisk_cpio::compress() {
return; return;
fprintf(stderr, "Compressing cpio -> [%s]\n", RAMDISK_XZ); fprintf(stderr, "Compressing cpio -> [%s]\n", RAMDISK_XZ);
auto init = entries.extract("init"); auto init = entries.extract("init");
XZEncoder encoder;
encoder.setOut(make_unique<BufOutStream>()); uint8_t *data;
output(encoder); size_t len;
encoder.finalize(); FILE *fp = open_stream(get_encoder(XZ, open_stream<byte_stream>(data, len)));
dump(fp);
entries.clear(); entries.clear();
entries.insert(std::move(init)); entries.insert(std::move(init));
auto xz = new cpio_entry(RAMDISK_XZ, S_IFREG); auto xz = new cpio_entry(RAMDISK_XZ, S_IFREG);
static_cast<BufOutStream *>(encoder.getOut())->release(xz->data, xz->filesize); xz->data = data;
xz->filesize = len;
insert(xz); insert(xz);
} }
...@@ -257,15 +260,16 @@ void magisk_cpio::decompress() { ...@@ -257,15 +260,16 @@ void magisk_cpio::decompress() {
if (it == entries.end()) if (it == entries.end())
return; return;
fprintf(stderr, "Decompressing cpio [%s]\n", RAMDISK_XZ); fprintf(stderr, "Decompressing cpio [%s]\n", RAMDISK_XZ);
LZMADecoder decoder;
decoder.setOut(make_unique<BufOutStream>()); char *data;
decoder.write(it->second->data, it->second->filesize); size_t len;
decoder.finalize(); auto strm = get_decoder(XZ, open_stream<byte_stream>(data, len));
strm->write(it->second->data, it->second->filesize);
delete strm;
entries.erase(it); entries.erase(it);
char *buf; load_cpio(data, len);
size_t sz; free(data);
static_cast<BufOutStream *>(decoder.getOut())->getbuf(buf, sz);
load_cpio(buf, sz);
} }
int cpio_commands(int argc, char *argv[]) { int cpio_commands(int argc, char *argv[]) {
......
...@@ -49,8 +49,7 @@ cpio_entry_base::cpio_entry_base(const cpio_newc_header *h) ...@@ -49,8 +49,7 @@ cpio_entry_base::cpio_entry_base(const cpio_newc_header *h)
void cpio::dump(const char *file) { void cpio::dump(const char *file) {
fprintf(stderr, "Dump cpio: [%s]\n", file); fprintf(stderr, "Dump cpio: [%s]\n", file);
FDOutStream fd_out(xopen(file, O_WRONLY | O_CREAT | O_TRUNC, 0644), true); dump(xfopen(file, "we"));
output(fd_out);
} }
void cpio::rm(entry_map::iterator &it) { void cpio::rm(entry_map::iterator &it) {
...@@ -110,9 +109,9 @@ bool cpio::exists(const char *name) { ...@@ -110,9 +109,9 @@ bool cpio::exists(const char *name) {
return entries.count(name) != 0; return entries.count(name) != 0;
} }
#define do_out(b, l) out.write(b, l); pos += (l) #define do_out(buf, len) pos += fwrite(buf, len, 1, out);
#define out_align() out.write(zeros, align_off(pos, 4)); pos = do_align(pos, 4) #define out_align() do_out(zeros, align_off(pos, 4))
void cpio::output(OutStream &out) { void cpio::dump(FILE *out) {
size_t pos = 0; size_t pos = 0;
unsigned inode = 300000; unsigned inode = 300000;
char header[111]; char header[111];
...@@ -147,6 +146,7 @@ void cpio::output(OutStream &out) { ...@@ -147,6 +146,7 @@ void cpio::output(OutStream &out) {
do_out(header, 110); do_out(header, 110);
do_out("TRAILER!!!\0", 11); do_out("TRAILER!!!\0", 11);
out_align(); out_align();
fclose(out);
} }
cpio_rw::cpio_rw(const char *file) { cpio_rw::cpio_rw(const char *file) {
...@@ -221,12 +221,12 @@ bool cpio_rw::mv(const char *from, const char *to) { ...@@ -221,12 +221,12 @@ bool cpio_rw::mv(const char *from, const char *to) {
#define pos_align(p) p = do_align(p, 4) #define pos_align(p) p = do_align(p, 4)
void cpio_rw::load_cpio(char *buf, size_t sz) { void cpio_rw::load_cpio(const char *buf, size_t sz) {
size_t pos = 0; size_t pos = 0;
cpio_newc_header *header; const cpio_newc_header *header;
unique_ptr<cpio_entry> entry; unique_ptr<cpio_entry> entry;
while (pos < sz) { while (pos < sz) {
header = (cpio_newc_header *)(buf + pos); header = reinterpret_cast<const cpio_newc_header *>(buf + pos);
entry = make_unique<cpio_entry>(header); entry = make_unique<cpio_entry>(header);
pos += sizeof(*header); pos += sizeof(*header);
string_view name_view(buf + pos); string_view name_view(buf + pos);
......
...@@ -30,7 +30,7 @@ struct cpio_entry : public cpio_entry_base { ...@@ -30,7 +30,7 @@ struct cpio_entry : public cpio_entry_base {
explicit cpio_entry(const char *name, uint32_t mode) : filename(name) { explicit cpio_entry(const char *name, uint32_t mode) : filename(name) {
this->mode = mode; this->mode = mode;
} }
explicit cpio_entry(cpio_newc_header *h) : cpio_entry_base(h) {} explicit cpio_entry(const cpio_newc_header *h) : cpio_entry_base(h) {}
~cpio_entry() override { free(data); }; ~cpio_entry() override { free(data); };
}; };
...@@ -48,7 +48,7 @@ public: ...@@ -48,7 +48,7 @@ public:
protected: protected:
entry_map entries; entry_map entries;
void rm(entry_map::iterator &it); void rm(entry_map::iterator &it);
void output(OutStream &out); void dump(FILE *out);
}; };
class cpio_rw : public cpio { class cpio_rw : public cpio {
...@@ -64,7 +64,7 @@ public: ...@@ -64,7 +64,7 @@ public:
protected: protected:
void insert(cpio_entry *e); void insert(cpio_entry *e);
void mv(entry_map::iterator &it, const char *to); void mv(entry_map::iterator &it, const char *to);
void load_cpio(char *buf, size_t sz); void load_cpio(const char *buf, size_t sz);
}; };
class cpio_mmap : public cpio { class cpio_mmap : public cpio {
......
...@@ -15,8 +15,6 @@ FILE *open_stream(Args &&... args) { ...@@ -15,8 +15,6 @@ FILE *open_stream(Args &&... args) {
return open_stream(new T(args...)); return open_stream(new T(args...));
} }
/* Base classes */
class stream { class stream {
public: public:
virtual int read(void *buf, size_t len); virtual int read(void *buf, size_t len);
...@@ -26,17 +24,17 @@ public: ...@@ -26,17 +24,17 @@ public:
virtual ~stream() = default; virtual ~stream() = default;
}; };
// Delegates all operations to the base FILE pointer
class filter_stream : public stream { class filter_stream : public stream {
public: public:
filter_stream(FILE *fp) : fp(fp) {} filter_stream(FILE *fp) : fp(fp) {}
int close() override { return fclose(fp); } ~filter_stream() override { if (fp) close(); }
virtual ~filter_stream() { close(); }
void set_base(FILE *f) { int read(void *buf, size_t len) override;
if (fp) fclose(fp); int write(const void *buf, size_t len) override;
fp = f; int close() override;
}
void set_base(FILE *f);
template <class T, class... Args > template <class T, class... Args >
void set_base(Args&&... args) { void set_base(Args&&... args) {
set_base(open_stream<T>(args...)); set_base(open_stream<T>(args...));
...@@ -46,18 +44,7 @@ protected: ...@@ -46,18 +44,7 @@ protected:
FILE *fp; FILE *fp;
}; };
class filter_in_stream : public filter_stream { // Handy interface for classes that need custom seek logic
public:
filter_in_stream(FILE *fp = nullptr) : filter_stream(fp) {}
int read(void *buf, size_t len) override { return fread(buf, len, 1, fp); }
};
class filter_out_stream : public filter_stream {
public:
filter_out_stream(FILE *fp = nullptr) : filter_stream(fp) {}
int write(const void *buf, size_t len) override { return fwrite(buf, len, 1, fp); }
};
class seekable_stream : public stream { class seekable_stream : public stream {
protected: protected:
size_t _pos = 0; size_t _pos = 0;
...@@ -66,8 +53,7 @@ protected: ...@@ -66,8 +53,7 @@ protected:
virtual size_t end_pos() = 0; virtual size_t end_pos() = 0;
}; };
/* Concrete classes */ // Byte stream that dynamically allocates memory
class byte_stream : public seekable_stream { class byte_stream : public seekable_stream {
public: public:
byte_stream(uint8_t *&buf, size_t &len); byte_stream(uint8_t *&buf, size_t &len);
...@@ -76,7 +62,6 @@ public: ...@@ -76,7 +62,6 @@ public:
int read(void *buf, size_t len) override; int read(void *buf, size_t len) override;
int write(const void *buf, size_t len) override; int write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override; off_t seek(off_t off, int whence) override;
virtual ~byte_stream() = default;
private: private:
uint8_t *&_buf; uint8_t *&_buf;
...@@ -87,101 +72,14 @@ private: ...@@ -87,101 +72,14 @@ private:
size_t end_pos() override { return _len; } size_t end_pos() override { return _len; }
}; };
class fd_stream : stream { // File stream but does not close the file descriptor at any time
class fd_stream : public stream {
public: public:
fd_stream(int fd) : fd(fd) {} fd_stream(int fd) : fd(fd) {}
int read(void *buf, size_t len) override; int read(void *buf, size_t len) override;
int write(const void *buf, size_t len) override; int write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override; off_t seek(off_t off, int whence) override;
virtual ~fd_stream() = default;
private: private:
int fd; int fd;
}; };
/* TODO: Replace classes below to new implementation */
class OutStream {
public:
virtual bool write(const void *buf, size_t len) = 0;
virtual ~OutStream() = default;
};
typedef std::unique_ptr<OutStream> strm_ptr;
class FilterOutStream : public OutStream {
public:
FilterOutStream() = default;
FilterOutStream(strm_ptr &&ptr) : out(std::move(ptr)) {}
void setOut(strm_ptr &&ptr) { out = std::move(ptr); }
OutStream *getOut() { return out.get(); }
bool write(const void *buf, size_t len) override {
return out ? out->write(buf, len) : false;
}
protected:
strm_ptr out;
};
class FDOutStream : public OutStream {
public:
FDOutStream(int fd, bool close = false) : fd(fd), close(close) {}
bool write(const void *buf, size_t len) override {
return ::write(fd, buf, len) == len;
}
~FDOutStream() override {
if (close)
::close(fd);
}
protected:
int fd;
bool close;
};
class BufOutStream : public OutStream {
public:
BufOutStream() : buf(nullptr), off(0), cap(0) {};
bool write(const void *b, size_t len) override {
bool resize = false;
while (off + len > cap) {
cap = cap ? cap << 1 : 1 << 19;
resize = true;
}
if (resize)
buf = (char *) xrealloc(buf, cap);
memcpy(buf + off, b, len);
off += len;
return true;
}
template <typename bytes, typename length>
void release(bytes *&b, length &len) {
b = buf;
len = off;
buf = nullptr;
off = cap = 0;
}
template <typename bytes, typename length>
void getbuf(bytes *&b, length &len) const {
b = buf;
len = off;
}
~BufOutStream() override {
free(buf);
}
protected:
char *buf;
size_t off;
size_t cap;
};
...@@ -49,6 +49,25 @@ int stream::close() { ...@@ -49,6 +49,25 @@ int stream::close() {
return 0; return 0;
} }
int filter_stream::read(void *buf, size_t len) {
return fread(buf, len, 1, fp);
}
int filter_stream::write(const void *buf, size_t len) {
return fwrite(buf, len, 1, fp);
}
int filter_stream::close() {
int ret = fclose(fp);
fp = nullptr;
return ret;
}
void filter_stream::set_base(FILE *f) {
if (fp) fclose(fp);
fp = f;
}
off_t seekable_stream::new_pos(off_t off, int whence) { off_t seekable_stream::new_pos(off_t off, int whence) {
off_t new_pos; off_t new_pos;
switch (whence) { switch (whence) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment