diff --git a/bufferpool.go b/bufferpool.go new file mode 100644 index 0000000..b3d7000 --- /dev/null +++ b/bufferpool.go @@ -0,0 +1,30 @@ +package gozstd + +import ( + "bytes" + "sync" +) + +var compInBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, cstreamInBufSize)) + }, +} + +var compOutBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, cstreamOutBufSize)) + }, +} + +var decInBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, dstreamInBufSize)) + }, +} + +var decOutBufPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(make([]byte, 0, dstreamOutBufSize)) + }, +} diff --git a/go.mod b/go.mod index f422410..49a61b1 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/valyala/gozstd -go 1.12 +go 1.20 diff --git a/gozstd.go b/gozstd.go index 831b703..a7abbad 100644 --- a/gozstd.go +++ b/gozstd.go @@ -7,29 +7,27 @@ package gozstd #include "zstd.h" #include "zstd_errors.h" -#include // for uintptr_t - // The following *_wrapper functions allow avoiding memory allocations // durting calls from Go. // See https://github.com/golang/go/issues/24450 . -static size_t ZSTD_compressCCtx_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, int compressionLevel) { - return ZSTD_compressCCtx((ZSTD_CCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, compressionLevel); +static size_t ZSTD_compressCCtx_wrapper(void *ctx, void *dst, size_t dstCapacity, const void *src, size_t srcSize, int compressionLevel) { + return ZSTD_compressCCtx((ZSTD_CCtx*)ctx, dst, dstCapacity, src, srcSize, compressionLevel); } -static size_t ZSTD_compress_usingCDict_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, uintptr_t cdict) { +static size_t ZSTD_compress_usingCDict_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize, void *cdict) { return ZSTD_compress_usingCDict((ZSTD_CCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, (const ZSTD_CDict*)cdict); } -static size_t ZSTD_decompressDCtx_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize) { +static size_t ZSTD_decompressDCtx_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize) { return ZSTD_decompressDCtx((ZSTD_DCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize); } -static size_t ZSTD_decompress_usingDDict_wrapper(uintptr_t ctx, uintptr_t dst, size_t dstCapacity, uintptr_t src, size_t srcSize, uintptr_t ddict) { +static size_t ZSTD_decompress_usingDDict_wrapper(void *ctx, void *dst, size_t dstCapacity, void *src, size_t srcSize, void *ddict) { return ZSTD_decompress_usingDDict((ZSTD_DCtx*)ctx, (void*)dst, dstCapacity, (const void*)src, srcSize, (const ZSTD_DDict*)ddict); } -static unsigned long long ZSTD_findDecompressedSize_wrapper(uintptr_t src, size_t srcSize) { +static unsigned long long ZSTD_findDecompressedSize_wrapper(void *src, size_t srcSize) { return ZSTD_findDecompressedSize((const void*)src, srcSize); } @@ -47,6 +45,8 @@ import ( // DefaultCompressionLevel is the default compression level. const DefaultCompressionLevel = 3 // Obtained from ZSTD_CLEVEL_DEFAULT. +const maxFrameContentSize = 256 << 20 // 256 MB + // Compress appends compressed src to dst and returns the result. func Compress(dst, src []byte) []byte { return compressDictLevel(dst, src, nil, DefaultCompressionLevel) @@ -147,36 +147,55 @@ func compress(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressi return dst } +// noescape hides a pointer from escape analysis. It is the identity function +// but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// This is copied from go's strings.Builder. Allows us to use stack-allocated +// slices. +// +//go:nosplit +//go:nocheckptr +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + func compressInternal(cctx, cctxDict *cctxWrapper, dst, src []byte, cd *CDict, compressionLevel int, mustSucceed bool) C.size_t { + // using noescape will allow this to work with stack-allocated slices + dstPtr := noescape(unsafe.Pointer(unsafe.SliceData(dst))) + srcPtr := noescape(unsafe.Pointer(unsafe.SliceData(src))) + if cd != nil { result := C.ZSTD_compress_usingCDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cctxDict.cctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(cctxDict.cctx), + dstPtr, C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + srcPtr, C.size_t(len(src)), - C.uintptr_t(uintptr(unsafe.Pointer(cd.p)))) + unsafe.Pointer(cd.p)) // Prevent from GC'ing of dst and src during CGO call above. runtime.KeepAlive(dst) runtime.KeepAlive(src) if mustSucceed { - ensureNoError("ZSTD_compress_usingCDict_wrapper", result) + ensureNoError("ZSTD_compress_usingCDict", result) } return result } + result := C.ZSTD_compressCCtx_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cctx.cctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(cctx.cctx), + dstPtr, C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + srcPtr, C.size_t(len(src)), C.int(compressionLevel)) // Prevent from GC'ing of dst and src during CGO call above. runtime.KeepAlive(dst) runtime.KeepAlive(src) if mustSucceed { - ensureNoError("ZSTD_compressCCtx_wrapper", result) + ensureNoError("ZSTD_compressCCtx", result) } + return result } @@ -255,17 +274,15 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte } // Slow path - resize dst to fit decompressed data. - decompressBound := int(C.ZSTD_findDecompressedSize_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), C.size_t(len(src)))) - // Prevent from GC'ing of src during CGO call above. - runtime.KeepAlive(src) - switch uint64(decompressBound) { - case uint64(C.ZSTD_CONTENTSIZE_UNKNOWN): + srcPtr := noescape(unsafe.Pointer(unsafe.SliceData(src))) + contentSize := C.ZSTD_findDecompressedSize_wrapper(srcPtr, C.size_t(len(src))) + switch { + case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize: return streamDecompress(dst, src, dd) - case uint64(C.ZSTD_CONTENTSIZE_ERROR): + case contentSize == C.ZSTD_CONTENTSIZE_ERROR: return dst, fmt.Errorf("cannot decompress invalid src") } - decompressBound++ + decompressBound := int(contentSize) + 1 if n := dstLen + decompressBound - cap(dst); n > 0 { // This should be optimized since go 1.11 - see https://golang.org/doc/go1.11#performance-compiler. @@ -288,24 +305,28 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte } func decompressInternal(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) C.size_t { - var n C.size_t + var ( + dstPtr = noescape(unsafe.Pointer(unsafe.SliceData(dst))) + srcPtr = noescape(unsafe.Pointer(unsafe.SliceData(src))) + n C.size_t + ) if dd != nil { n = C.ZSTD_decompress_usingDDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(dctxDict.dctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(dctxDict.dctx), + dstPtr, C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + srcPtr, C.size_t(len(src)), - C.uintptr_t(uintptr(unsafe.Pointer(dd.p)))) + unsafe.Pointer(dd.p)) } else { n = C.ZSTD_decompressDCtx_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(dctx.dctx))), - C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))), + unsafe.Pointer(dctx.dctx), + dstPtr, C.size_t(cap(dst)), - C.uintptr_t(uintptr(unsafe.Pointer(&src[0]))), + srcPtr, C.size_t(len(src))) } - // Prevent from GC'ing of dst and src during CGO calls above. + // Prevent from GC'ing of dst and src during CGO call above. runtime.KeepAlive(dst) runtime.KeepAlive(src) return n @@ -318,13 +339,17 @@ func errStr(result C.size_t) string { } func ensureNoError(funcName string, result C.size_t) { + if zstdIsError(result) { + panic(fmt.Errorf("BUG: unexpected error in %s: %s", funcName, errStr(result))) + } +} + +func zstdIsError(result C.size_t) bool { if int(result) >= 0 { // Fast path - avoid calling C function. - return - } - if C.ZSTD_getErrorCode(result) != 0 { - panic(fmt.Errorf("BUG: unexpected error in %s: %s", funcName, errStr(result))) + return false } + return C.ZSTD_isError(result) != 0 } func streamDecompress(dst, src []byte, dd *DDict) ([]byte, error) { diff --git a/gozstd_example_test.go b/gozstd_example_test.go index 0f43be9..91fdfb0 100644 --- a/gozstd_example_test.go +++ b/gozstd_example_test.go @@ -21,18 +21,18 @@ func ExampleCompress_simple() { } func ExampleCompress_multiple() { - data := []byte("foo bar baz") - - // Compress and decompress data into new buffers. - compressedData := Compress(nil, data) - decompressedData, err := Decompress(nil, append(compressedData, compressedData...)) - if err != nil { - log.Fatalf("cannot decompress data: %s", err) - } - - fmt.Printf("%s", decompressedData) - // Output: - // foo bar bazfoo bar baz + data := []byte("foo bar baz") + + // Compress and decompress data into new buffers. + compressedData := Compress(nil, data) + decompressedData, err := Decompress(nil, append(compressedData, compressedData...)) + if err != nil { + log.Fatalf("cannot decompress data: %s", err) + } + + fmt.Printf("%s", decompressedData) + // Output: + // foo bar bazfoo bar baz } func ExampleDecompress_simple() { diff --git a/gozstd_test.go b/gozstd_test.go index 3a465b8..7729b66 100644 --- a/gozstd_test.go +++ b/gozstd_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" + "io" "math/rand" "runtime" "strings" @@ -54,6 +55,22 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) { }) } +func TestCompressEmpty(t *testing.T) { + var dst [64]byte + res := Compress(dst[:0], nil) + if len(res) > 0 { + t.Fatalf("unexpected non-empty compressed frame: %X", res) + } +} + +func TestDecompressTooLarge(t *testing.T) { + src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67} + _, err := Decompress(nil, src) + if err == nil { + t.Fatalf("expecting error when decompressing malformed frame") + } +} + func mustUnhex(dataHex string) []byte { data, err := hex.DecodeString(dataHex) if err != nil { @@ -62,6 +79,48 @@ func mustUnhex(dataHex string) []byte { return data } +func TestCompressWithStackMove(t *testing.T) { + var srcBuf [96]byte + + n, err := io.ReadFull(rand.New(rand.NewSource(time.Now().Unix())), srcBuf[:]) + if err != nil { + t.Fatalf("cannot fill srcBuf with random data: %s", err) + } + + // We're running this twice, because the first run will allocate + // objects in sync.Pool, calls to which extend the stack, and the second + // run can skip those allocations and extend the stack right before + // the CGO call. + // Note that this test might require some go:nosplit annotations + // to force the stack move to happen exactly before the CGO call. + for i := 0; i < 2; i++ { + ch := make(chan struct{}) + go func() { + defer close(ch) + + var dstBuf [1416]byte + + res := Compress(dstBuf[:0], srcBuf[:n]) + + // make a copy of the result, so the original can remain on the stack + compressedCpy := make([]byte, len(res)) + copy(compressedCpy, res) + + orig, err := Decompress(nil, compressedCpy) + if err != nil { + panic(fmt.Errorf("cannot decompress: %s", err)) + } + if !bytes.Equal(orig, srcBuf[:n]) { + panic(fmt.Errorf("unexpected decompressed data; got %q; want %q", orig, srcBuf[:n])) + } + }() + // wait for the goroutine to finish + <-ch + } + + runtime.GC() +} + func TestCompressDecompressDistinctConcurrentDicts(t *testing.T) { // Build multiple distinct dicts. var cdicts []*CDict diff --git a/reader.go b/reader.go index 9ebcb96..d656a22 100644 --- a/reader.go +++ b/reader.go @@ -7,14 +7,18 @@ package gozstd #include "zstd.h" #include "zstd_errors.h" -#include // for malloc/free -#include // for uintptr_t +typedef struct { + size_t dstSize; + size_t srcSize; + size_t dstPos; + size_t srcPos; +} ZSTD_EXT_BufferSizes; // The following *_wrapper functions allow avoiding memory allocations // durting calls from Go. // See https://github.com/golang/go/issues/24450 . -static size_t ZSTD_initDStream_usingDDict_wrapper(uintptr_t ds, uintptr_t dict) { +static size_t ZSTD_initDStream_usingDDict_wrapper(void *ds, void *dict) { ZSTD_DStream *zds = (ZSTD_DStream *)ds; size_t rv = ZSTD_DCtx_reset(zds, ZSTD_reset_session_only); if (rv != 0) { @@ -23,23 +27,26 @@ static size_t ZSTD_initDStream_usingDDict_wrapper(uintptr_t ds, uintptr_t dict) return ZSTD_DCtx_refDDict(zds, (ZSTD_DDict *)dict); } -static size_t ZSTD_freeDStream_wrapper(uintptr_t ds) { +static size_t ZSTD_freeDStream_wrapper(void *ds) { return ZSTD_freeDStream((ZSTD_DStream*)ds); } -static size_t ZSTD_decompressStream_wrapper(uintptr_t ds, uintptr_t output, uintptr_t input) { - return ZSTD_decompressStream((ZSTD_DStream*)ds, (ZSTD_outBuffer*)output, (ZSTD_inBuffer*)input); +static size_t ZSTD_decompressStream_wrapper(void *ds, void* dst, const void* src, ZSTD_EXT_BufferSizes* sizes) { + return ZSTD_decompressStream_simpleArgs((ZSTD_DStream*)ds, dst, sizes->dstSize, &sizes->dstPos, src, sizes->srcSize, &sizes->srcPos); } */ import "C" import ( + "bytes" "fmt" "io" "runtime" "unsafe" ) +const minDirectWriteBufferSize = 16 * 1024 + var ( dstreamInBufSize = C.ZSTD_DStreamInSize() dstreamOutBufSize = C.ZSTD_DStreamOutSize() @@ -51,11 +58,17 @@ type Reader struct { ds *C.ZSTD_DStream dd *DDict - inBuf *C.ZSTD_inBuffer - outBuf *C.ZSTD_outBuffer + inBufWrapper *bytes.Buffer + outBufWrapper *bytes.Buffer + + skipNextRead bool - inBufGo cMemPtr - outBufGo cMemPtr + readerPos int + inBuf []byte + outBuf []byte + // go doesn't allow passing pointers to structs with pointers to Go memory + // so we can't use ZSTD_inBuffer and ZSTD_outBuffer directly + sizes C.ZSTD_EXT_BufferSizes } // NewReader returns new zstd reader reading compressed data from r. @@ -73,37 +86,29 @@ func NewReaderDict(r io.Reader, dd *DDict) *Reader { ds := C.ZSTD_createDStream() initDStream(ds, dd) - inBuf := (*C.ZSTD_inBuffer)(C.calloc(1, C.sizeof_ZSTD_inBuffer)) - inBuf.src = C.calloc(1, dstreamInBufSize) - inBuf.size = 0 - inBuf.pos = 0 - - outBuf := (*C.ZSTD_outBuffer)(C.calloc(1, C.sizeof_ZSTD_outBuffer)) - outBuf.dst = C.calloc(1, dstreamOutBufSize) - outBuf.size = 0 - outBuf.pos = 0 + inBufWrapper := decInBufPool.Get().(*bytes.Buffer) + outBufWrapper := decOutBufPool.Get().(*bytes.Buffer) zr := &Reader{ - r: r, - ds: ds, - dd: dd, - inBuf: inBuf, - outBuf: outBuf, + r: r, + ds: ds, + dd: dd, + inBufWrapper: inBufWrapper, + outBufWrapper: outBufWrapper, + inBuf: inBufWrapper.Bytes(), + outBuf: outBufWrapper.Bytes(), } - zr.inBufGo = cMemPtr(zr.inBuf.src) - zr.outBufGo = cMemPtr(zr.outBuf.dst) - runtime.SetFinalizer(zr, freeDStream) return zr } // Reset resets zr to read from r using the given dictionary dd. func (zr *Reader) Reset(r io.Reader, dd *DDict) { - zr.inBuf.size = 0 - zr.inBuf.pos = 0 - zr.outBuf.size = 0 - zr.outBuf.pos = 0 + zr.readerPos = 0 + zr.sizes = C.ZSTD_EXT_BufferSizes{} + zr.inBuf = zr.inBuf[:0] + zr.outBuf = zr.outBuf[:0] zr.dd = dd initDStream(zr.ds, zr.dd) @@ -116,9 +121,7 @@ func initDStream(ds *C.ZSTD_DStream, dd *DDict) { if dd != nil { ddict = dd.p } - result := C.ZSTD_initDStream_usingDDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(ds))), - C.uintptr_t(uintptr(unsafe.Pointer(ddict)))) + result := C.ZSTD_initDStream_usingDDict_wrapper(unsafe.Pointer(ds), unsafe.Pointer(ddict)) ensureNoError("ZSTD_initDStream_usingDDict", result) } @@ -134,21 +137,23 @@ func (zr *Reader) Release() { return } - result := C.ZSTD_freeDStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zr.ds)))) + result := C.ZSTD_freeDStream_wrapper(unsafe.Pointer(zr.ds)) ensureNoError("ZSTD_freeDStream", result) zr.ds = nil - C.free(zr.inBuf.src) - C.free(unsafe.Pointer(zr.inBuf)) - zr.inBuf = nil - - C.free(zr.outBuf.dst) - C.free(unsafe.Pointer(zr.outBuf)) - zr.outBuf = nil - zr.r = nil zr.dd = nil + + if zr.inBuf != nil { + zr.inBuf = nil + decInBufPool.Put(zr.inBufWrapper) + zr.inBufWrapper = nil + } + if zr.outBuf != nil { + zr.outBuf = nil + decOutBufPool.Put(zr.outBufWrapper) + zr.outBufWrapper = nil + } } // WriteTo writes all the data from zr to w. @@ -157,16 +162,17 @@ func (zr *Reader) Release() { func (zr *Reader) WriteTo(w io.Writer) (int64, error) { nn := int64(0) for { - if zr.outBuf.pos == zr.outBuf.size { - if err := zr.fillOutBuf(); err != nil { + if zr.readerPos >= len(zr.outBuf) { + if _, err := zr.fillOutBuf(nil); err != nil { if err == io.EOF { return nn, nil } return nn, err } + zr.readerPos = 0 } - n, err := w.Write(zr.outBufGo[zr.outBuf.pos:zr.outBuf.size]) - zr.outBuf.pos += C.size_t(n) + n, err := w.Write(zr.outBuf[zr.readerPos:]) + zr.readerPos += n nn += int64(n) if err != nil { return nn, err @@ -180,51 +186,68 @@ func (zr *Reader) Read(p []byte) (int, error) { return 0, nil } - if zr.outBuf.pos == zr.outBuf.size { - if err := zr.fillOutBuf(); err != nil { + if zr.readerPos >= len(zr.outBuf) { + if len(p) >= minDirectWriteBufferSize { + // write directly into the target buffer + // but make sure to override its capacity + return zr.fillOutBuf(p[:len(p):len(p)]) + } + if _, err := zr.fillOutBuf(nil); err != nil { return 0, err } + zr.readerPos = 0 } - n := copy(p, zr.outBufGo[zr.outBuf.pos:zr.outBuf.size]) - zr.outBuf.pos += C.size_t(n) + n := copy(p, zr.outBuf[zr.readerPos:]) + zr.readerPos += n return n, nil } -func (zr *Reader) fillOutBuf() error { - if zr.inBuf.pos == zr.inBuf.size && zr.outBuf.size < dstreamOutBufSize { +func (zr *Reader) fillOutBuf(target []byte) (int, error) { + dst := target + if dst == nil { + dst = zr.outBuf + } + + if int(zr.sizes.srcPos) == len(zr.inBuf) && !zr.skipNextRead { // inBuf is empty and the previously decompressed data size - // is smaller than the maximum possible zr.outBuf.size. + // is smaller than the maximum possible dst.size. // This means that the internal buffer in zr.ds doesn't contain // more data to decompress, so read new data into inBuf. if err := zr.fillInBuf(); err != nil { - return err + return 0, err } } + zr.sizes.dstSize = C.size_t(cap(dst)) + zr.sizes.dstPos = 0 + + srcBuf := unsafe.SliceData(zr.inBuf) + dstBuf := unsafe.SliceData(dst) tryDecompressAgain: + zr.sizes.srcSize = C.size_t(len(zr.inBuf)) + prevInBufPos := zr.sizes.srcPos + // Try decompressing inBuf into outBuf. - zr.outBuf.size = dstreamOutBufSize - zr.outBuf.pos = 0 - prevInBufPos := zr.inBuf.pos result := C.ZSTD_decompressStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zr.ds))), - C.uintptr_t(uintptr(unsafe.Pointer(zr.outBuf))), - C.uintptr_t(uintptr(unsafe.Pointer(zr.inBuf)))) - zr.outBuf.size = zr.outBuf.pos - zr.outBuf.pos = 0 - - if C.ZSTD_getErrorCode(result) != 0 { - return fmt.Errorf("cannot decompress data: %s", errStr(result)) + unsafe.Pointer(zr.ds), unsafe.Pointer(dstBuf), unsafe.Pointer(srcBuf), &zr.sizes) + + zr.skipNextRead = int(zr.sizes.dstPos) == cap(dst) + if target == nil { + zr.outBuf = zr.outBuf[:zr.sizes.dstPos] + } + + if zstdIsError(result) { + return int(zr.sizes.dstPos), fmt.Errorf("cannot decompress data: %s", errStr(result)) } - if zr.outBuf.size > 0 { + if zr.sizes.dstPos > 0 { // Something has been decompressed to outBuf. Return it. - return nil + return int(zr.sizes.dstPos), nil } // Nothing has been decompressed from inBuf. - if zr.inBuf.pos != prevInBufPos && zr.inBuf.pos < zr.inBuf.size { + if zr.sizes.srcPos != prevInBufPos && int(zr.sizes.srcPos) < len(zr.inBuf) { // Data has been consumed from inBuf, but decompressed // into nothing. There is more data in inBuf, so try // decompressing it again. @@ -235,21 +258,31 @@ tryDecompressAgain: // decompressed into nothing and inBuf became empty. // Read more data into inBuf and try decompressing again. if err := zr.fillInBuf(); err != nil { - return err + return 0, err } + goto tryDecompressAgain } func (zr *Reader) fillInBuf() error { - // Copy the remaining data to the start of inBuf. - copy(zr.inBufGo[:dstreamInBufSize], zr.inBufGo[zr.inBuf.pos:zr.inBuf.size]) - zr.inBuf.size -= zr.inBuf.pos - zr.inBuf.pos = 0 + if zr.sizes.srcPos > 0 { + if int(zr.sizes.srcPos) == len(zr.inBuf) { + // we've read all the data from inBuf, reset it + zr.inBuf = zr.inBuf[:0] + zr.sizes.srcPos = 0 + } else if int(zr.sizes.srcPos) > cap(zr.inBuf)/2 { + // Copy the remaining data to the start of inBuf. + copy(zr.inBuf[:cap(zr.inBuf)], zr.inBuf[zr.sizes.srcPos:]) + zr.inBuf = zr.inBuf[:len(zr.inBuf)-int(zr.sizes.srcPos)] + zr.sizes.srcPos = 0 + } + } readAgain: // Read more data into inBuf. - n, err := zr.r.Read(zr.inBufGo[zr.inBuf.size:dstreamInBufSize]) - zr.inBuf.size += C.size_t(n) + n, err := zr.r.Read(zr.inBuf[len(zr.inBuf):cap(zr.inBuf)]) + zr.inBuf = zr.inBuf[:len(zr.inBuf)+n] + if err == nil { if n == 0 { // Nothing has been read. Try reading data again. @@ -265,5 +298,5 @@ readAgain: // Do not wrap io.EOF, so the caller may notify the end of stream. return err } - return fmt.Errorf("cannot read data from the underlying reader: %s", err) + return fmt.Errorf("cannot read data from the underlying reader: %w", err) } diff --git a/writer.go b/writer.go index 40e4b30..761dcff 100644 --- a/writer.go +++ b/writer.go @@ -7,45 +7,51 @@ package gozstd #include "zstd.h" #include "zstd_errors.h" -#include // for malloc/free -#include // for uintptr_t +typedef struct { + size_t dstSize; + size_t srcSize; + size_t dstPos; + size_t srcPos; +} ZSTD_EXT_BufferSizes; // The following *_wrapper functions allow avoiding memory allocations // durting calls from Go. // See https://github.com/golang/go/issues/24450 . - -static size_t ZSTD_CCtx_setParameter_wrapper(uintptr_t cs, ZSTD_cParameter param, int value) { - return ZSTD_CCtx_setParameter((ZSTD_CStream*)cs, param, value); +static size_t ZSTD_compressStream_wrapper(void *cs, void* dst, const void* src, ZSTD_EXT_BufferSizes* sizes, ZSTD_EndDirective endOp) { + return ZSTD_compressStream2_simpleArgs((ZSTD_CStream*)cs, dst, sizes->dstSize, &sizes->dstPos, src, sizes->srcSize, &sizes->srcPos, endOp); } -static size_t ZSTD_initCStream_wrapper(uintptr_t cs, int compressionLevel) { - return ZSTD_initCStream((ZSTD_CStream*)cs, compressionLevel); -} +static size_t ZSTD_flushStream_wrapper(void *cs, void *dst, ZSTD_EXT_BufferSizes* sizes) { + size_t res; + ZSTD_outBuffer outBuf; -static size_t ZSTD_CCtx_refCDict_wrapper(uintptr_t cc, uintptr_t dict) { - return ZSTD_CCtx_refCDict((ZSTD_CCtx*)cc, (ZSTD_CDict*)dict); -} + outBuf.dst = dst; + outBuf.size = sizes->dstSize; + outBuf.pos = sizes->dstPos; -static size_t ZSTD_freeCStream_wrapper(uintptr_t cs) { - return ZSTD_freeCStream((ZSTD_CStream*)cs); + res = ZSTD_flushStream((ZSTD_CStream*)cs, &outBuf); + sizes->dstPos = outBuf.pos; + return res; } -static size_t ZSTD_compressStream_wrapper(uintptr_t cs, uintptr_t output, uintptr_t input) { - return ZSTD_compressStream((ZSTD_CStream*)cs, (ZSTD_outBuffer*)output, (ZSTD_inBuffer*)input); -} +static size_t ZSTD_endStream_wrapper(void *cs, void *dst, ZSTD_EXT_BufferSizes* sizes) { + size_t res; + ZSTD_outBuffer outBuf; -static size_t ZSTD_flushStream_wrapper(uintptr_t cs, uintptr_t output) { - return ZSTD_flushStream((ZSTD_CStream*)cs, (ZSTD_outBuffer*)output); -} + outBuf.dst = dst; + outBuf.size = sizes->dstSize; + outBuf.pos = sizes->dstPos; -static size_t ZSTD_endStream_wrapper(uintptr_t cs, uintptr_t output) { - return ZSTD_endStream((ZSTD_CStream*)cs, (ZSTD_outBuffer*)output); + res = ZSTD_endStream((ZSTD_CStream*)cs, &outBuf); + sizes->dstPos = outBuf.pos; + return res; } */ import "C" import ( + "bytes" "fmt" "io" "runtime" @@ -57,8 +63,6 @@ var ( cstreamOutBufSize = C.ZSTD_CStreamOutSize() ) -type cMemPtr *[1 << 30]byte - // Writer implements zstd writer. type Writer struct { w io.Writer @@ -67,11 +71,12 @@ type Writer struct { cs *C.ZSTD_CStream cd *CDict - inBuf *C.ZSTD_inBuffer - outBuf *C.ZSTD_outBuffer + inBufWrapper *bytes.Buffer + outBufWrapper *bytes.Buffer - inBufGo cMemPtr - outBufGo cMemPtr + inBuf []byte + outBuf []byte + sizes C.ZSTD_EXT_BufferSizes } // NewWriter returns new zstd writer writing compressed data to w. @@ -160,15 +165,8 @@ func NewWriterParams(w io.Writer, params *WriterParams) *Writer { cs := C.ZSTD_createCStream() initCStream(cs, *params) - inBuf := (*C.ZSTD_inBuffer)(C.calloc(1, C.sizeof_ZSTD_inBuffer)) - inBuf.src = C.calloc(1, cstreamInBufSize) - inBuf.size = 0 - inBuf.pos = 0 - - outBuf := (*C.ZSTD_outBuffer)(C.calloc(1, C.sizeof_ZSTD_outBuffer)) - outBuf.dst = C.calloc(1, cstreamOutBufSize) - outBuf.size = cstreamOutBufSize - outBuf.pos = 0 + inBufWrapper := compInBufPool.Get().(*bytes.Buffer) + outBufWrapper := compOutBufPool.Get().(*bytes.Buffer) zw := &Writer{ w: w, @@ -176,13 +174,12 @@ func NewWriterParams(w io.Writer, params *WriterParams) *Writer { wlog: params.WindowLog, cs: cs, cd: params.Dict, - inBuf: inBuf, - outBuf: outBuf, + inBufWrapper: inBufWrapper, + outBufWrapper: outBufWrapper, + inBuf: inBufWrapper.Bytes(), + outBuf: outBufWrapper.Bytes(), } - zw.inBufGo = cMemPtr(zw.inBuf.src) - zw.outBufGo = cMemPtr(zw.outBuf.dst) - runtime.SetFinalizer(zw, freeCStream) return zw } @@ -201,10 +198,9 @@ func (zw *Writer) Reset(w io.Writer, cd *CDict, compressionLevel int) { // ResetWriterParams resets zw to write to w using the given set of parameters. func (zw *Writer) ResetWriterParams(w io.Writer, params *WriterParams) { - zw.inBuf.size = 0 - zw.inBuf.pos = 0 - zw.outBuf.size = cstreamOutBufSize - zw.outBuf.pos = 0 + zw.inBuf = zw.inBuf[:0] + zw.outBuf = zw.outBuf[:0] + zw.sizes = C.ZSTD_EXT_BufferSizes{} zw.cd = params.Dict initCStream(zw.cs, *params) @@ -214,21 +210,14 @@ func (zw *Writer) ResetWriterParams(w io.Writer, params *WriterParams) { func initCStream(cs *C.ZSTD_CStream, params WriterParams) { if params.Dict != nil { - result := C.ZSTD_CCtx_refCDict_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cs))), - C.uintptr_t(uintptr(unsafe.Pointer(params.Dict.p)))) + result := C.ZSTD_CCtx_refCDict(cs, params.Dict.p) ensureNoError("ZSTD_CCtx_refCDict", result) } else { - result := C.ZSTD_initCStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cs))), - C.int(params.CompressionLevel)) + result := C.ZSTD_initCStream(cs, C.int(params.CompressionLevel)) ensureNoError("ZSTD_initCStream", result) } - result := C.ZSTD_CCtx_setParameter_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(cs))), - C.ZSTD_cParameter(C.ZSTD_c_windowLog), - C.int(params.WindowLog)) + result := C.ZSTD_CCtx_setParameter(cs, C.ZSTD_c_windowLog, C.int(params.WindowLog)) ensureNoError("ZSTD_CCtx_setParameter", result) } @@ -244,21 +233,24 @@ func (zw *Writer) Release() { return } - result := C.ZSTD_freeCStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs)))) + result := C.ZSTD_freeCStream(zw.cs) ensureNoError("ZSTD_freeCStream", result) zw.cs = nil - C.free(unsafe.Pointer(zw.inBuf.src)) - C.free(unsafe.Pointer(zw.inBuf)) - zw.inBuf = nil - - C.free(unsafe.Pointer(zw.outBuf.dst)) - C.free(unsafe.Pointer(zw.outBuf)) - zw.outBuf = nil - zw.w = nil zw.cd = nil + + if zw.inBufWrapper != nil { + zw.inBuf = nil + compInBufPool.Put(zw.inBufWrapper) + zw.inBufWrapper = nil + } + + if zw.outBufWrapper != nil { + zw.outBuf = nil + compOutBufPool.Put(zw.outBufWrapper) + zw.outBufWrapper = nil + } } // ReadFrom reads all the data from r and writes it to zw. @@ -272,13 +264,15 @@ func (zw *Writer) Release() { func (zw *Writer) ReadFrom(r io.Reader) (int64, error) { nn := int64(0) for { + inBuf := zw.inBuf[len(zw.inBuf):cap(zw.inBuf)] // Fill the inBuf. - for zw.inBuf.size < cstreamInBufSize { - n, err := r.Read(zw.inBufGo[zw.inBuf.size:cstreamInBufSize]) + for len(inBuf) > 0 { + n, err := r.Read(inBuf) // Sometimes n > 0 even when Read() returns an error. // This is true especially if the error is io.EOF. - zw.inBuf.size += C.size_t(n) + inBuf = inBuf[n:] + zw.inBuf = zw.inBuf[:len(zw.inBuf)+n] nn += int64(n) if err != nil { @@ -309,8 +303,8 @@ func (zw *Writer) Write(p []byte) (int, error) { } for { - n := copy(zw.inBufGo[zw.inBuf.size:cstreamInBufSize], p) - zw.inBuf.size += C.size_t(n) + n := copy(zw.inBuf[len(zw.inBuf):cap(zw.inBuf)], p) + zw.inBuf = zw.inBuf[:len(zw.inBuf)+n] p = p[n:] if len(p) == 0 { // Fast path - just copy the data to input buffer. @@ -323,19 +317,30 @@ func (zw *Writer) Write(p []byte) (int, error) { } func (zw *Writer) flushInBuf() error { - prevInBufPos := zw.inBuf.pos + zw.sizes.dstSize = C.size_t(cap(zw.outBuf)) + zw.sizes.dstPos = C.size_t(len(zw.outBuf)) + zw.sizes.srcSize = C.size_t(len(zw.inBuf)) + zw.sizes.srcPos = 0 + + dstBuf := unsafe.SliceData(zw.outBuf) + srcBuf := unsafe.SliceData(zw.inBuf) + result := C.ZSTD_compressStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.outBuf))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.inBuf)))) - ensureNoError("ZSTD_compressStream", result) + unsafe.Pointer(zw.cs), unsafe.Pointer(dstBuf), unsafe.Pointer(srcBuf), + &zw.sizes, C.ZSTD_e_continue) + ensureNoError("ZSTD_compressStream_wrapper", result) + + zw.outBuf = zw.outBuf[:zw.sizes.dstPos] // Move the remaining data to the start of inBuf. - copy(zw.inBufGo[:cstreamInBufSize], zw.inBufGo[zw.inBuf.pos:zw.inBuf.size]) - zw.inBuf.size -= zw.inBuf.pos - zw.inBuf.pos = 0 + if int(zw.sizes.srcPos) < len(zw.inBuf) { + copy(zw.inBuf[:cap(zw.inBuf)], zw.inBuf[zw.sizes.srcPos:len(zw.inBuf)]) + zw.inBuf = zw.inBuf[:len(zw.inBuf)-int(zw.sizes.srcPos)] + } else { + zw.inBuf = zw.inBuf[:0] + } - if zw.outBuf.size-zw.outBuf.pos > zw.outBuf.pos && prevInBufPos != zw.inBuf.pos { + if cap(zw.outBuf)-int(zw.sizes.dstPos) > int(zw.sizes.dstPos) && zw.sizes.srcPos > 0 { // There is enough space in outBuf and the last compression // succeeded, so don't flush outBuf yet. return nil @@ -347,20 +352,20 @@ func (zw *Writer) flushInBuf() error { } func (zw *Writer) flushOutBuf() error { - if zw.outBuf.pos == 0 { + if len(zw.outBuf) == 0 { // Nothing to flush. return nil } - outBuf := zw.outBufGo[:zw.outBuf.pos] - n, err := zw.w.Write(outBuf) - zw.outBuf.pos = 0 + bufLen := len(zw.outBuf) + n, err := zw.w.Write(zw.outBuf) + zw.outBuf = zw.outBuf[:0] if err != nil { return fmt.Errorf("cannot flush internal buffer to the underlying writer: %s", err) } - if n != len(outBuf) { + if n != bufLen { panic(fmt.Errorf("BUG: the underlying writer violated io.Writer contract and didn't return error after writing incomplete data; written %d bytes; want %d bytes", - n, len(outBuf))) + n, bufLen)) } return nil } @@ -368,7 +373,7 @@ func (zw *Writer) flushOutBuf() error { // Flush flushes the remaining data from zw to the underlying writer. func (zw *Writer) Flush() error { // Flush inBuf. - for zw.inBuf.size > 0 { + for len(zw.inBuf) > 0 { if err := zw.flushInBuf(); err != nil { return err } @@ -376,10 +381,14 @@ func (zw *Writer) Flush() error { // Flush the internal buffer to outBuf. for { + dstBuf := unsafe.SliceData(zw.outBuf) + zw.sizes.dstSize = C.size_t(cap(zw.outBuf)) + zw.sizes.dstPos = C.size_t(len(zw.outBuf)) + result := C.ZSTD_flushStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.outBuf)))) + unsafe.Pointer(zw.cs), unsafe.Pointer(dstBuf), &zw.sizes) ensureNoError("ZSTD_flushStream", result) + zw.outBuf = zw.outBuf[:zw.sizes.dstPos] if err := zw.flushOutBuf(); err != nil { return err } @@ -400,10 +409,15 @@ func (zw *Writer) Close() error { } for { + dstBuf := unsafe.SliceData(zw.outBuf) + zw.sizes.dstSize = C.size_t(cap(zw.outBuf)) + zw.sizes.dstPos = C.size_t(len(zw.outBuf)) + result := C.ZSTD_endStream_wrapper( - C.uintptr_t(uintptr(unsafe.Pointer(zw.cs))), - C.uintptr_t(uintptr(unsafe.Pointer(zw.outBuf)))) + unsafe.Pointer(zw.cs), + unsafe.Pointer(dstBuf), &zw.sizes) ensureNoError("ZSTD_endStream", result) + zw.outBuf = zw.outBuf[:zw.sizes.dstPos] if err := zw.flushOutBuf(); err != nil { return err } diff --git a/writer_timing_test.go b/writer_timing_test.go index 66da8f4..0d1d328 100644 --- a/writer_timing_test.go +++ b/writer_timing_test.go @@ -2,7 +2,7 @@ package gozstd import ( "fmt" - "io/ioutil" + "io" "testing" ) @@ -26,7 +26,7 @@ func benchmarkWriterDict(b *testing.B, blockSize, level int) { b.ReportAllocs() b.SetBytes(int64(len(block))) b.RunParallel(func(pb *testing.PB) { - zw := NewWriterDict(ioutil.Discard, bd.cd) + zw := NewWriterDict(io.Discard, bd.cd) defer zw.Release() for pb.Next() { for i := 0; i < benchBlocksPerStream; i++ { @@ -38,7 +38,7 @@ func benchmarkWriterDict(b *testing.B, blockSize, level int) { if err := zw.Close(); err != nil { panic(fmt.Errorf("unexpected error: %s", err)) } - zw.Reset(ioutil.Discard, bd.cd, level) + zw.Reset(io.Discard, bd.cd, level) } }) } @@ -60,7 +60,7 @@ func benchmarkWriter(b *testing.B, blockSize, level int) { b.ReportAllocs() b.SetBytes(int64(len(block))) b.RunParallel(func(pb *testing.PB) { - zw := NewWriterLevel(ioutil.Discard, level) + zw := NewWriterLevel(io.Discard, level) defer zw.Release() for pb.Next() { for i := 0; i < benchBlocksPerStream; i++ { @@ -72,7 +72,7 @@ func benchmarkWriter(b *testing.B, blockSize, level int) { if err := zw.Close(); err != nil { panic(fmt.Errorf("unexpected error: %s", err)) } - zw.Reset(ioutil.Discard, nil, level) + zw.Reset(io.Discard, nil, level) } }) } @@ -82,11 +82,11 @@ func BenchmarkWriterResetAlloc(b *testing.B) { params := &WriterParams{} - zw := NewWriter(ioutil.Discard) + zw := NewWriter(io.Discard) defer zw.Release() for n := 0; n < b.N; n++ { - zw.Reset(ioutil.Discard, nil, 0) - zw.ResetWriterParams(ioutil.Discard, params) + zw.Reset(io.Discard, nil, 0) + zw.ResetWriterParams(io.Discard, params) } }