// Copyright 2024 The Gitea Authors. All rights reserved. // SPDX-License-Identifier: MIT // Package zstd provides a high-level API for reading and writing zstd-compressed data. // It supports both regular and seekable zstd streams. // It's not a new wheel, but a wrapper around the zstd and zstd-seekable-format-go packages. package zstd import ( "errors" "io" seekable "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg" "github.com/klauspost/compress/zstd" ) type Writer zstd.Encoder var _ io.WriteCloser = (*Writer)(nil) // NewWriter returns a new zstd writer. func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) { zstdW, err := zstd.NewWriter(w, opts...) if err != nil { return nil, err } return (*Writer)(zstdW), nil } func (w *Writer) Write(p []byte) (int, error) { return (*zstd.Encoder)(w).Write(p) } func (w *Writer) Close() error { return (*zstd.Encoder)(w).Close() } type Reader zstd.Decoder var _ io.ReadCloser = (*Reader)(nil) // NewReader returns a new zstd reader. func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) { zstdR, err := zstd.NewReader(r, opts...) if err != nil { return nil, err } return (*Reader)(zstdR), nil } func (r *Reader) Read(p []byte) (int, error) { return (*zstd.Decoder)(r).Read(p) } func (r *Reader) Close() error { (*zstd.Decoder)(r).Close() // no error returned return nil } type SeekableWriter struct { buf []byte n int w seekable.Writer } var _ io.WriteCloser = (*SeekableWriter)(nil) // NewSeekableWriter returns a zstd writer to compress data to seekable format. // blockSize is an important parameter, it should be decided according to the actual business requirements. // If it's too small, the compression ratio could be very bad, even no compression at all. // If it's too large, it could cost more traffic when reading the data partially from underlying storage. func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*SeekableWriter, error) { zstdW, err := zstd.NewWriter(nil, opts...) if err != nil { return nil, err } seekableW, err := seekable.NewWriter(w, zstdW) if err != nil { return nil, err } return &SeekableWriter{ buf: make([]byte, blockSize), w: seekableW, }, nil } func (w *SeekableWriter) Write(p []byte) (int, error) { written := 0 for len(p) > 0 { n := copy(w.buf[w.n:], p) w.n += n written += n p = p[n:] if w.n == len(w.buf) { if _, err := w.w.Write(w.buf); err != nil { return written, err } w.n = 0 } } return written, nil } func (w *SeekableWriter) Close() error { if w.n > 0 { if _, err := w.w.Write(w.buf[:w.n]); err != nil { return err } } return w.w.Close() } type SeekableReader struct { r seekable.Reader c func() error } var _ io.ReadSeekCloser = (*SeekableReader)(nil) // NewSeekableReader returns a zstd reader to decompress data from seekable format. func NewSeekableReader(r io.ReadSeeker, opts ...ReaderOption) (*SeekableReader, error) { zstdR, err := zstd.NewReader(nil, opts...) if err != nil { return nil, err } seekableR, err := seekable.NewReader(r, zstdR) if err != nil { return nil, err } ret := &SeekableReader{ r: seekableR, } if closer, ok := r.(io.Closer); ok { ret.c = closer.Close } return ret, nil } func (r *SeekableReader) Read(p []byte) (int, error) { return r.r.Read(p) } func (r *SeekableReader) Seek(offset int64, whence int) (int64, error) { return r.r.Seek(offset, whence) } func (r *SeekableReader) Close() error { return errors.Join( func() error { if r.c != nil { return r.c() } return nil }(), r.r.Close(), ) }