Signed-off-by: Miloslav Trmač <mitr@redhat.com>
This commit is contained in:
Miloslav Trmač
2023-01-23 16:22:08 +01:00
parent 47919520f5
commit 48b9d94c87
73 changed files with 11300 additions and 634 deletions

View File

@@ -192,16 +192,14 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
}
// Read block data.
if cap(b.dataStorage) < cSize {
if _, ok := br.(*byteBuf); !ok && cap(b.dataStorage) < cSize {
// byteBuf doesn't need a destination buffer.
if b.lowMem || cSize > maxCompressedBlockSize {
b.dataStorage = make([]byte, 0, cSize+compressedBlockOverAlloc)
} else {
b.dataStorage = make([]byte, 0, maxCompressedBlockSizeAlloc)
}
}
if cap(b.dst) <= maxSize {
b.dst = make([]byte, 0, maxSize+1)
}
b.data, err = br.readBig(cSize, b.dataStorage)
if err != nil {
if debugDecoder {
@@ -210,6 +208,9 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {
}
return err
}
if cap(b.dst) <= maxSize {
b.dst = make([]byte, 0, maxSize+1)
}
return nil
}

View File

@@ -40,8 +40,7 @@ type Decoder struct {
frame *frameDec
// Custom dictionaries.
// Always uses copies.
dicts map[uint32]dict
dicts map[uint32]*dict
// streamWg is the waitgroup for all streams
streamWg sync.WaitGroup
@@ -103,7 +102,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
}
// Transfer option dicts.
d.dicts = make(map[uint32]dict, len(d.o.dicts))
d.dicts = make(map[uint32]*dict, len(d.o.dicts))
for _, dc := range d.o.dicts {
d.dicts[dc.id] = dc
}
@@ -341,15 +340,8 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
}
return dst, err
}
if frame.DictionaryID != nil {
dict, ok := d.dicts[*frame.DictionaryID]
if !ok {
return nil, ErrUnknownDictionary
}
if debugDecoder {
println("setting dict", frame.DictionaryID)
}
frame.history.setDict(&dict)
if err = d.setDict(frame); err != nil {
return nil, err
}
if frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
@@ -495,18 +487,12 @@ func (d *Decoder) nextBlockSync() (ok bool) {
if !d.syncStream.inFrame {
d.frame.history.reset()
d.current.err = d.frame.reset(&d.syncStream.br)
if d.current.err == nil {
d.current.err = d.setDict(d.frame)
}
if d.current.err != nil {
return false
}
if d.frame.DictionaryID != nil {
dict, ok := d.dicts[*d.frame.DictionaryID]
if !ok {
d.current.err = ErrUnknownDictionary
return false
} else {
d.frame.history.setDict(&dict)
}
}
if d.frame.WindowSize > d.o.maxDecodedSize || d.frame.WindowSize > d.o.maxWindowSize {
d.current.err = ErrDecoderSizeExceeded
return false
@@ -865,13 +851,8 @@ decodeStream:
if debugDecoder && err != nil {
println("Frame decoder returned", err)
}
if err == nil && frame.DictionaryID != nil {
dict, ok := d.dicts[*frame.DictionaryID]
if !ok {
err = ErrUnknownDictionary
} else {
frame.history.setDict(&dict)
}
if err == nil {
err = d.setDict(frame)
}
if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
@@ -953,3 +934,20 @@ decodeStream:
hist.reset()
d.frame.history.b = frameHistCache
}
func (d *Decoder) setDict(frame *frameDec) (err error) {
dict, ok := d.dicts[frame.DictionaryID]
if ok {
if debugDecoder {
println("setting dict", frame.DictionaryID)
}
frame.history.setDict(dict)
} else if frame.DictionaryID != 0 {
// A zero or missing dictionary id is ambiguous:
// either dictionary zero, or no dictionary. In particular,
// zstd --patch-from uses this id for the source file,
// so only return an error if the dictionary id is not zero.
err = ErrUnknownDictionary
}
return err
}

View File

@@ -6,6 +6,8 @@ package zstd
import (
"errors"
"fmt"
"math/bits"
"runtime"
)
@@ -18,7 +20,7 @@ type decoderOptions struct {
concurrent int
maxDecodedSize uint64
maxWindowSize uint64
dicts []dict
dicts []*dict
ignoreChecksum bool
limitToCap bool
decodeBufsBelow int
@@ -85,7 +87,13 @@ func WithDecoderMaxMemory(n uint64) DOption {
}
// WithDecoderDicts allows to register one or more dictionaries for the decoder.
// If several dictionaries with the same ID is provided the last one will be used.
//
// Each slice in dict must be in the [dictionary format] produced by
// "zstd --train" from the Zstandard reference implementation.
//
// If several dictionaries with the same ID are provided, the last one will be used.
//
// [dictionary format]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
func WithDecoderDicts(dicts ...[]byte) DOption {
return func(o *decoderOptions) error {
for _, b := range dicts {
@@ -93,12 +101,24 @@ func WithDecoderDicts(dicts ...[]byte) DOption {
if err != nil {
return err
}
o.dicts = append(o.dicts, *d)
o.dicts = append(o.dicts, d)
}
return nil
}
}
// WithEncoderDictRaw registers a dictionary that may be used by the decoder.
// The slice content can be arbitrary data.
func WithDecoderDictRaw(id uint32, content []byte) DOption {
return func(o *decoderOptions) error {
if bits.UintSize > 32 && uint(len(content)) > dictMaxLength {
return fmt.Errorf("dictionary of size %d > 2GiB too large", len(content))
}
o.dicts = append(o.dicts, &dict{id: id, content: content, offsets: [3]int{1, 4, 8}})
return nil
}
}
// WithDecoderMaxWindow allows to set a maximum window size for decodes.
// This allows rejecting packets that will cause big memory usage.
// The Decoder will likely allocate more memory based on the WithDecoderLowmem setting.

View File

@@ -21,6 +21,9 @@ type dict struct {
const dictMagic = "\x37\xa4\x30\xec"
// Maximum dictionary size for the reference implementation (1.5.3) is 2 GiB.
const dictMaxLength = 1 << 31
// ID returns the dictionary id or 0 if d is nil.
func (d *dict) ID() uint32 {
if d == nil {

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"math"
"math/bits"
"runtime"
"strings"
)
@@ -305,7 +306,13 @@ func WithLowerEncoderMem(b bool) EOption {
}
// WithEncoderDict allows to register a dictionary that will be used for the encode.
//
// The slice dict must be in the [dictionary format] produced by
// "zstd --train" from the Zstandard reference implementation.
//
// The encoder *may* choose to use no dictionary instead for certain payloads.
//
// [dictionary format]: https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
func WithEncoderDict(dict []byte) EOption {
return func(o *encoderOptions) error {
d, err := loadDict(dict)
@@ -316,3 +323,17 @@ func WithEncoderDict(dict []byte) EOption {
return nil
}
}
// WithEncoderDictRaw registers a dictionary that may be used by the encoder.
//
// The slice content may contain arbitrary data. It will be used as an initial
// history.
func WithEncoderDictRaw(id uint32, content []byte) EOption {
return func(o *encoderOptions) error {
if bits.UintSize > 32 && uint(len(content)) > dictMaxLength {
return fmt.Errorf("dictionary of size %d > 2GiB too large", len(content))
}
o.dict = &dict{id: id, content: content, offsets: [3]int{1, 4, 8}}
return nil
}
}

View File

@@ -29,7 +29,7 @@ type frameDec struct {
FrameContentSize uint64
DictionaryID *uint32
DictionaryID uint32
HasCheckSum bool
SingleSegment bool
}
@@ -155,7 +155,7 @@ func (d *frameDec) reset(br byteBuffer) error {
// Read Dictionary_ID
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
d.DictionaryID = nil
d.DictionaryID = 0
if size := fhd & 3; size != 0 {
if size == 3 {
size = 4
@@ -178,11 +178,7 @@ func (d *frameDec) reset(br byteBuffer) error {
if debugDecoder {
println("Dict size", size, "ID:", id)
}
if id > 0 {
// ID 0 means "sorry, no dictionary anyway".
// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
d.DictionaryID = &id
}
d.DictionaryID = id
}
// Read Frame_Content_Size

View File

@@ -72,7 +72,6 @@ var (
ErrDecoderSizeExceeded = errors.New("decompressed size exceeds configured limit")
// ErrUnknownDictionary is returned if the dictionary ID is unknown.
// For the time being dictionaries are not supported.
ErrUnknownDictionary = errors.New("unknown dictionary")
// ErrFrameSizeExceeded is returned if the stated frame size is exceeded.