diff --git a/source/ByteBuffer.cpp b/source/ByteBuffer.cpp index 623b865f4..724745855 100644 --- a/source/ByteBuffer.cpp +++ b/source/ByteBuffer.cpp @@ -39,6 +39,7 @@ cByteBuffer::cByteBuffer(int a_BufferSize) : cByteBuffer::~cByteBuffer() { + CheckValid(); delete[] m_Buffer; } @@ -48,7 +49,9 @@ cByteBuffer::~cByteBuffer() bool cByteBuffer::Write(const char * a_Bytes, int a_Count) { - // DEBUG: Store the current free space for a check after writing + CheckValid(); + + // Store the current free space for a check after writing: int CurFreeSpace = GetFreeSpace(); int CurReadableSpace = GetReadableSpace(); int WrittenBytes = 0; @@ -58,20 +61,26 @@ bool cByteBuffer::Write(const char * a_Bytes, int a_Count) return false; } int TillEnd = m_BufferSize - m_WritePos; - if (TillEnd < a_Count) + if (TillEnd <= a_Count) { // Need to wrap around the ringbuffer end - memcpy(m_Buffer + m_WritePos, a_Bytes, TillEnd); + if (TillEnd > 0) + { + memcpy(m_Buffer + m_WritePos, a_Bytes, TillEnd); + a_Bytes += TillEnd; + a_Count -= TillEnd; + WrittenBytes = TillEnd; + } m_WritePos = 0; - a_Bytes += TillEnd; - a_Count -= TillEnd; - WrittenBytes = TillEnd; } // We're guaranteed that we'll fit in a single write op - memcpy(m_Buffer + m_WritePos, a_Bytes, a_Count); - m_WritePos += a_Count; - WrittenBytes += a_Count; + if (a_Count > 0) + { + memcpy(m_Buffer + m_WritePos, a_Bytes, a_Count); + m_WritePos += a_Count; + WrittenBytes += a_Count; + } ASSERT(GetFreeSpace() == CurFreeSpace - WrittenBytes); ASSERT(GetReadableSpace() == CurReadableSpace + WrittenBytes); @@ -84,6 +93,7 @@ bool cByteBuffer::Write(const char * a_Bytes, int a_Count) int cByteBuffer::GetFreeSpace(void) const { + CheckValid(); if (m_WritePos >= m_DataStart) { // Wrap around the buffer end: @@ -100,6 +110,7 @@ int cByteBuffer::GetFreeSpace(void) const /// Returns the number of bytes that are currently in the ringbuffer. Note GetReadableBytes() int cByteBuffer::GetUsedSpace(void) const { + CheckValid(); return m_BufferSize - GetFreeSpace(); } @@ -110,6 +121,7 @@ int cByteBuffer::GetUsedSpace(void) const /// Returns the number of bytes that are currently available for reading (may be less than UsedSpace due to some data having been read already) int cByteBuffer::GetReadableSpace(void) const { + CheckValid(); if (m_ReadPos > m_WritePos) { // Wrap around the buffer end: @@ -125,6 +137,7 @@ int cByteBuffer::GetReadableSpace(void) const bool cByteBuffer::CanReadBytes(int a_Count) const { + CheckValid(); return (a_Count <= GetReadableSpace()); } @@ -134,6 +147,7 @@ bool cByteBuffer::CanReadBytes(int a_Count) const bool cByteBuffer::CanWriteBytes(int a_Count) const { + CheckValid(); return (a_Count <= GetFreeSpace()); } @@ -143,6 +157,7 @@ bool cByteBuffer::CanWriteBytes(int a_Count) const bool cByteBuffer::ReadChar(char & a_Value) { + CheckValid(); NEEDBYTES(1); ReadBuf(&a_Value, 1); return true; @@ -154,6 +169,7 @@ bool cByteBuffer::ReadChar(char & a_Value) bool cByteBuffer::ReadByte(unsigned char & a_Value) { + CheckValid(); NEEDBYTES(1); ReadBuf(&a_Value, 1); return true; @@ -165,6 +181,7 @@ bool cByteBuffer::ReadByte(unsigned char & a_Value) bool cByteBuffer::ReadBEShort(short & a_Value) { + CheckValid(); NEEDBYTES(2); ReadBuf(&a_Value, 2); a_Value = ntohs(a_Value); @@ -177,6 +194,7 @@ bool cByteBuffer::ReadBEShort(short & a_Value) bool cByteBuffer::ReadBEInt(int & a_Value) { + CheckValid(); NEEDBYTES(4); ReadBuf(&a_Value, 4); a_Value = ntohl(a_Value); @@ -189,6 +207,7 @@ bool cByteBuffer::ReadBEInt(int & a_Value) bool cByteBuffer::ReadBEInt64(Int64 & a_Value) { + CheckValid(); NEEDBYTES(8); ReadBuf(&a_Value, 8); a_Value = NetworkToHostLong8(&a_Value); @@ -201,6 +220,7 @@ bool cByteBuffer::ReadBEInt64(Int64 & a_Value) bool cByteBuffer::ReadBEFloat(float & a_Value) { + CheckValid(); NEEDBYTES(4); ReadBuf(&a_Value, 4); a_Value = NetworkToHostFloat4(&a_Value); @@ -213,6 +233,7 @@ bool cByteBuffer::ReadBEFloat(float & a_Value) bool cByteBuffer::ReadBEDouble(double & a_Value) { + CheckValid(); NEEDBYTES(8); ReadBuf(&a_Value, 8); a_Value = NetworkToHostDouble8(&a_Value); @@ -225,6 +246,7 @@ bool cByteBuffer::ReadBEDouble(double & a_Value) bool cByteBuffer::ReadBool(bool & a_Value) { + CheckValid(); NEEDBYTES(1); a_Value = (m_Buffer[m_ReadPos++] != 0); return true; @@ -236,6 +258,7 @@ bool cByteBuffer::ReadBool(bool & a_Value) bool cByteBuffer::ReadBEUTF16String16(AString & a_Value) { + CheckValid(); short Length; if (!ReadBEShort(Length)) { @@ -250,6 +273,7 @@ bool cByteBuffer::ReadBEUTF16String16(AString & a_Value) bool cByteBuffer::WriteChar(char a_Value) { + CheckValid(); PUTBYTES(1); return WriteBuf(&a_Value, 1); } @@ -260,6 +284,7 @@ bool cByteBuffer::WriteChar(char a_Value) bool cByteBuffer::WriteByte(unsigned char a_Value) { + CheckValid(); PUTBYTES(1); return WriteBuf(&a_Value, 1); } @@ -270,6 +295,7 @@ bool cByteBuffer::WriteByte(unsigned char a_Value) bool cByteBuffer::WriteBEShort(short a_Value) { + CheckValid(); PUTBYTES(2); short Converted = htons(a_Value); return WriteBuf(&Converted, 2); @@ -281,6 +307,7 @@ bool cByteBuffer::WriteBEShort(short a_Value) bool cByteBuffer::WriteBEInt(int a_Value) { + CheckValid(); PUTBYTES(4); int Converted = HostToNetwork4(&a_Value); return WriteBuf(&Converted, 4); @@ -292,6 +319,7 @@ bool cByteBuffer::WriteBEInt(int a_Value) bool cByteBuffer::WriteBEInt64(Int64 a_Value) { + CheckValid(); PUTBYTES(8); Int64 Converted = HostToNetwork8(&a_Value); return WriteBuf(&Converted, 8); @@ -303,6 +331,7 @@ bool cByteBuffer::WriteBEInt64(Int64 a_Value) bool cByteBuffer::WriteBEFloat(float a_Value) { + CheckValid(); PUTBYTES(4); int Converted = HostToNetwork4(&a_Value); return WriteBuf(&Converted, 4); @@ -314,6 +343,7 @@ bool cByteBuffer::WriteBEFloat(float a_Value) bool cByteBuffer::WriteBEDouble(double a_Value) { + CheckValid(); PUTBYTES(8); Int64 Converted = HostToNetwork8(&a_Value); return WriteBuf(&Converted, 8); @@ -326,6 +356,7 @@ bool cByteBuffer::WriteBEDouble(double a_Value) bool cByteBuffer::WriteBool(bool a_Value) { + CheckValid(); return WriteChar(a_Value ? 1 : 0); } @@ -335,6 +366,7 @@ bool cByteBuffer::WriteBool(bool a_Value) bool cByteBuffer::WriteBEUTF16String16(const AString & a_Value) { + CheckValid(); PUTBYTES(2); AString UTF16BE; UTF8ToRawBEUTF16(a_Value.data(), a_Value.size(), UTF16BE); @@ -350,12 +382,13 @@ bool cByteBuffer::WriteBEUTF16String16(const AString & a_Value) bool cByteBuffer::ReadBuf(void * a_Buffer, int a_Count) { + CheckValid(); ASSERT(a_Count >= 0); NEEDBYTES(a_Count); char * Dst = (char *)a_Buffer; // So that we can do byte math int BytesToEndOfBuffer = m_BufferSize - m_ReadPos; ASSERT(BytesToEndOfBuffer >= 0); // Sanity check - if (BytesToEndOfBuffer < a_Count) + if (BytesToEndOfBuffer <= a_Count) { // Reading across the ringbuffer end, read the first part and adjust parameters: if (BytesToEndOfBuffer > 0) @@ -368,8 +401,11 @@ bool cByteBuffer::ReadBuf(void * a_Buffer, int a_Count) } // Read the rest of the bytes in a single read (guaranteed to fit): - memcpy(Dst, m_Buffer + m_ReadPos, a_Count); - m_ReadPos += a_Count; + if (a_Count > 0) + { + memcpy(Dst, m_Buffer + m_ReadPos, a_Count); + m_ReadPos += a_Count; + } return true; } @@ -379,11 +415,12 @@ bool cByteBuffer::ReadBuf(void * a_Buffer, int a_Count) bool cByteBuffer::WriteBuf(const void * a_Buffer, int a_Count) { + CheckValid(); ASSERT(a_Count >= 0); PUTBYTES(a_Count); char * Src = (char *)a_Buffer; // So that we can do byte math int BytesToEndOfBuffer = m_BufferSize - m_WritePos; - if (BytesToEndOfBuffer < a_Count) + if (BytesToEndOfBuffer <= a_Count) { // Reading across the ringbuffer end, read the first part and adjust parameters: memcpy(m_Buffer + m_WritePos, Src, BytesToEndOfBuffer); @@ -393,8 +430,11 @@ bool cByteBuffer::WriteBuf(const void * a_Buffer, int a_Count) } // Read the rest of the bytes in a single read (guaranteed to fit): - memcpy(m_Buffer + m_WritePos, Src, a_Count); - m_WritePos += a_Count; + if (a_Count > 0) + { + memcpy(m_Buffer + m_WritePos, Src, a_Count); + m_WritePos += a_Count; + } return true; } @@ -404,22 +444,30 @@ bool cByteBuffer::WriteBuf(const void * a_Buffer, int a_Count) bool cByteBuffer::ReadString(AString & a_String, int a_Count) { + CheckValid(); ASSERT(a_Count >= 0); NEEDBYTES(a_Count); a_String.clear(); a_String.reserve(a_Count); int BytesToEndOfBuffer = m_BufferSize - m_ReadPos; - if (BytesToEndOfBuffer < a_Count) + ASSERT(BytesToEndOfBuffer >= 0); // Sanity check + if (BytesToEndOfBuffer <= a_Count) { // Reading across the ringbuffer end, read the first part and adjust parameters: - a_String.assign(m_Buffer + m_ReadPos, BytesToEndOfBuffer); - a_Count -= BytesToEndOfBuffer; + if (BytesToEndOfBuffer > 0) + { + a_String.assign(m_Buffer + m_ReadPos, BytesToEndOfBuffer); + a_Count -= BytesToEndOfBuffer; + } m_ReadPos = 0; } - + // Read the rest of the bytes in a single read (guaranteed to fit): - a_String.append(m_Buffer + m_ReadPos, a_Count); - m_ReadPos += a_Count; + if (a_Count > 0) + { + a_String.append(m_Buffer + m_ReadPos, a_Count); + m_ReadPos += a_Count; + } return true; } @@ -430,6 +478,7 @@ bool cByteBuffer::ReadString(AString & a_String, int a_Count) bool cByteBuffer::ReadUTF16String(AString & a_String, int a_NumChars) { // Reads 2 * a_NumChars bytes and interprets it as a UTF16 string, converting it into UTF8 string a_String + CheckValid(); ASSERT(a_NumChars >= 0); AString RawData; if (!ReadString(RawData, a_NumChars * 2)) @@ -446,6 +495,7 @@ bool cByteBuffer::ReadUTF16String(AString & a_String, int a_NumChars) bool cByteBuffer::SkipRead(int a_Count) { + CheckValid(); ASSERT(a_Count >= 0); if (!CanReadBytes(a_Count)) { @@ -461,6 +511,7 @@ bool cByteBuffer::SkipRead(int a_Count) void cByteBuffer::ReadAll(AString & a_Data) { + CheckValid(); ReadString(a_Data, GetReadableSpace()); } @@ -470,6 +521,7 @@ void cByteBuffer::ReadAll(AString & a_Data) void cByteBuffer::CommitRead(void) { + CheckValid(); m_DataStart = m_ReadPos; } @@ -479,6 +531,7 @@ void cByteBuffer::CommitRead(void) void cByteBuffer::ResetRead(void) { + CheckValid(); m_ReadPos = m_DataStart; } @@ -490,6 +543,7 @@ void cByteBuffer::ReadAgain(AString & a_Out) { // Return the data between m_DataStart and m_ReadPos (the data that has been read but not committed) // Used by ProtoProxy to repeat communication twice, once for parsing and the other time for the remote party + CheckValid(); int DataStart = m_DataStart; if (m_ReadPos < m_DataStart) { @@ -506,6 +560,7 @@ void cByteBuffer::ReadAgain(AString & a_Out) void cByteBuffer::AdvanceReadPos(int a_Count) { + CheckValid(); m_ReadPos += a_Count; if (m_ReadPos > m_BufferSize) { @@ -516,3 +571,15 @@ void cByteBuffer::AdvanceReadPos(int a_Count) + +void cByteBuffer::CheckValid(void) const +{ + ASSERT(m_ReadPos >= 0); + ASSERT(m_ReadPos < m_BufferSize); + ASSERT(m_WritePos >= 0); + ASSERT(m_WritePos < m_BufferSize); +} + + + + diff --git a/source/ByteBuffer.h b/source/ByteBuffer.h index 3981ab066..2a73ed597 100644 --- a/source/ByteBuffer.h +++ b/source/ByteBuffer.h @@ -97,6 +97,9 @@ public: /// Re-reads the data that has been read since the last commit to the current readpos. Used by ProtoProxy to duplicate communication void ReadAgain(AString & a_Out); + /// Checks if the internal state is valid (read and write positions in the correct bounds) using ASSERTs + void CheckValid(void) const; + protected: char * m_Buffer; int m_BufferSize; // Total size of the ringbuffer diff --git a/source/Protocol/Protocol125.cpp b/source/Protocol/Protocol125.cpp index 58afcdcae..beac46c69 100644 --- a/source/Protocol/Protocol125.cpp +++ b/source/Protocol/Protocol125.cpp @@ -96,8 +96,10 @@ enum { \ if (!m_ReceivedData.Proc(Var)) \ { \ + m_ReceivedData.CheckValid(); \ return PARSE_INCOMPLETE; \ } \ + m_ReceivedData.CheckValid(); \ } @@ -111,7 +113,7 @@ typedef unsigned char Byte; cProtocol125::cProtocol125(cClientHandle * a_Client) : super(a_Client), - m_ReceivedData(64 KiB) + m_ReceivedData(32 KiB) { } diff --git a/source/Protocol/Protocol132.cpp b/source/Protocol/Protocol132.cpp index 8471b5909..3e28e471d 100644 --- a/source/Protocol/Protocol132.cpp +++ b/source/Protocol/Protocol132.cpp @@ -25,8 +25,10 @@ { \ if (!m_ReceivedData.Proc(Var)) \ { \ + m_ReceivedData.CheckValid(); \ return PARSE_INCOMPLETE; \ } \ + m_ReceivedData.CheckValid(); \ } diff --git a/source/Protocol/Protocol142.cpp b/source/Protocol/Protocol142.cpp index a9d49cd79..b0dec0211 100644 --- a/source/Protocol/Protocol142.cpp +++ b/source/Protocol/Protocol142.cpp @@ -25,8 +25,10 @@ { \ if (!m_ReceivedData.Proc(Var)) \ { \ + m_ReceivedData.CheckValid(); \ return PARSE_INCOMPLETE; \ } \ + m_ReceivedData.CheckValid(); \ }