From e685b927a108977b99cdfbb475e23b4b551b1ca3 Mon Sep 17 00:00:00 2001 From: gucio321 Date: Sun, 16 May 2021 12:27:05 +0200 Subject: [PATCH] bitMuncher: add EnusreBits method --- d2common/d2data/d2compression/huffman.go | 10 ++++------ d2common/d2datautils/bitmuncher.go | 11 +++++++++++ d2common/d2datautils/bitmuncher_test.go | 22 +++++++++++++++++----- d2common/d2datautils/stream_writer.go | 8 ++++---- d2common/d2datautils/stream_writer_test.go | 6 +++--- 5 files changed, 39 insertions(+), 18 deletions(-) diff --git a/d2common/d2data/d2compression/huffman.go b/d2common/d2data/d2compression/huffman.go index c4c2f6f4..967af892 100644 --- a/d2common/d2data/d2compression/huffman.go +++ b/d2common/d2data/d2compression/huffman.go @@ -210,12 +210,10 @@ func decode(input *d2datautils.BitMuncher, head *linkedNode) *linkedNode { node := head for node.child0 != nil { - // checks if GetBit causes panic (End of file) - defer func() { - if r := recover(); r != nil { - log.Fatal("HuffmanDecompress: Unexpected end of file") - } - }() + if !input.EnsureBits(1) { + log.Fatal("Unexpected end of file") + } + bit := input.GetBit() if bit == 0 { node = node.child0 diff --git a/d2common/d2datautils/bitmuncher.go b/d2common/d2datautils/bitmuncher.go index 7352970e..83810158 100644 --- a/d2common/d2datautils/bitmuncher.go +++ b/d2common/d2datautils/bitmuncher.go @@ -136,3 +136,14 @@ func (v *BitMuncher) MakeSigned(value uint32, bits int) int32 { // Force casting to a signed value return int32(result) } + +// EnsureBits checks, if `count` bits is available +func (v *BitMuncher) EnsureBits(count int) bool { + bytesRead := v.offset / byteLen + bitOffset := v.offset % byteLen + numBytes := len(v.data) + remainingBytes := numBytes - bytesRead + remainingBits := remainingBytes*byteLen - bitOffset + + return count <= remainingBits +} diff --git a/d2common/d2datautils/bitmuncher_test.go b/d2common/d2datautils/bitmuncher_test.go index d8cfaf57..6141f81f 100644 --- a/d2common/d2datautils/bitmuncher_test.go +++ b/d2common/d2datautils/bitmuncher_test.go @@ -36,7 +36,7 @@ func TestBitmuncherReadBit(t *testing.T) { var result byte - for i := 0; i < bitsPerByte; i++ { + for i := 0; i < byteLen; i++ { v := bm.GetBit() result |= byte(v) << byte(i) } @@ -47,7 +47,7 @@ func TestBitmuncherReadBit(t *testing.T) { func TestBitmuncherGetBits(t *testing.T) { bm := CreateBitMuncher(testData, 0) - assert.Equal(t, byte(bm.GetBits(bitsPerByte)), testData[0], "get bits didn't return expected value") + assert.Equal(t, byte(bm.GetBits(byteLen)), testData[0], "get bits didn't return expected value") } func TestBitmuncherGetNoBits(t *testing.T) { @@ -77,7 +77,7 @@ func TestBitmuncherGetOneSignedBit(t *testing.T) { func TestBitmuncherSkipBits(t *testing.T) { bm := CreateBitMuncher(testData, 0) - bm.SkipBits(bitsPerByte) + bm.SkipBits(byteLen) assert.Equal(t, bm.GetByte(), testData[1], "skipping 8 bits didn't moved bit muncher's position into next byte") } @@ -88,7 +88,7 @@ func TestBitmuncherGetInt32(t *testing.T) { var testInt int32 for i := 0; i < bytesPerint32; i++ { - testInt |= int32(testData[i]) << int32(bitsPerByte*i) + testInt |= int32(testData[i]) << int32(byteLen*i) } assert.Equal(t, bm.GetInt32(), testInt, "int32 value wasn't returned properly") @@ -100,8 +100,20 @@ func TestBitmuncherGetUint32(t *testing.T) { var testUint uint32 for i := 0; i < bytesPerint32; i++ { - testUint |= uint32(testData[i]) << uint32(bitsPerByte*i) + testUint |= uint32(testData[i]) << uint32(byteLen*i) } assert.Equal(t, bm.GetUInt32(), testUint, "uint32 value wasn't returned properly") } + +func TestBitMuncherEnsureBits(t *testing.T) { + bm := CreateBitMuncher(testData, 0) + + assert.Equal(t, true, bm.EnsureBits(byteLen*len(testData)), "unexpected value returned by EnsureBits") + assert.Equal(t, false, bm.EnsureBits(byteLen*len(testData)+1), "unexpected value returned by EnsureBits") + + bm.SkipBits(5) + + assert.Equal(t, true, bm.EnsureBits(byteLen*len(testData)-5), "unexpected value returned by EnsureBits") + assert.Equal(t, false, bm.EnsureBits(byteLen*len(testData)-4), "unexpected value returned by EnsureBits") +} diff --git a/d2common/d2datautils/stream_writer.go b/d2common/d2datautils/stream_writer.go index 130fde63..03afe450 100644 --- a/d2common/d2datautils/stream_writer.go +++ b/d2common/d2datautils/stream_writer.go @@ -42,7 +42,7 @@ func (v *StreamWriter) PushBit(b bool) { } v.bitOffset++ - if v.bitOffset != bitsPerByte { + if v.bitOffset != byteLen { return } @@ -53,7 +53,7 @@ func (v *StreamWriter) PushBit(b bool) { // PushBits pushes bits (with max range 8) func (v *StreamWriter) PushBits(b byte, bits int) { - if bits > bitsPerByte { + if bits > byteLen { log.Print("input bits number must be less (or equal) than 8") } @@ -67,7 +67,7 @@ func (v *StreamWriter) PushBits(b byte, bits int) { // PushBits16 pushes bits (with max range 16) func (v *StreamWriter) PushBits16(b uint16, bits int) { - if bits > bitsPerByte*bytesPerint16 { + if bits > byteLen*bytesPerint16 { log.Print("input bits number must be less (or equal) than 16") } @@ -81,7 +81,7 @@ func (v *StreamWriter) PushBits16(b uint16, bits int) { // PushBits32 pushes bits (with max range 32) func (v *StreamWriter) PushBits32(b uint32, bits int) { - if bits > bitsPerByte*bytesPerint32 { + if bits > byteLen*bytesPerint32 { log.Print("input bits number must be less (or equal) than 32") } diff --git a/d2common/d2datautils/stream_writer_test.go b/d2common/d2datautils/stream_writer_test.go index 68c22cda..6a3cf709 100644 --- a/d2common/d2datautils/stream_writer_test.go +++ b/d2common/d2datautils/stream_writer_test.go @@ -9,7 +9,7 @@ func TestStreamWriterBits(t *testing.T) { data := []byte{221, 19} for _, i := range data { - sr.PushBits(i, bitsPerByte) + sr.PushBits(i, byteLen) } output := sr.GetBytes() @@ -25,7 +25,7 @@ func TestStreamWriterBits16(t *testing.T) { data := []uint16{1024, 19} for _, i := range data { - sr.PushBits16(i, bitsPerByte*bytesPerint16) + sr.PushBits16(i, byteLen*bytesPerint16) } output := sr.GetBytes() @@ -45,7 +45,7 @@ func TestStreamWriterBits32(t *testing.T) { data := []uint32{19324, 87} for _, i := range data { - sr.PushBits32(i, bitsPerByte*bytesPerint32) + sr.PushBits32(i, byteLen*bytesPerint32) } output := sr.GetBytes()