diff --git a/proxy/vmess/protocol/vmess.go b/proxy/vmess/protocol/vmess.go index 84fe1bc4d..cb8545295 100644 --- a/proxy/vmess/protocol/vmess.go +++ b/proxy/vmess/protocol/vmess.go @@ -34,9 +34,9 @@ const ( type VMessRequest struct { Version byte UserId user.ID - RequestIV [16]byte - RequestKey [16]byte - ResponseHeader [4]byte + RequestIV []byte + RequestKey []byte + ResponseHeader []byte Command byte Address v2net.Address } @@ -102,9 +102,9 @@ func (r *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) { return nil, errors.NewProtocolVersionError(int(request.Version)) } - copy(request.RequestIV[:], buffer[1:17]) // 16 bytes - copy(request.RequestKey[:], buffer[17:33]) // 16 bytes - copy(request.ResponseHeader[:], buffer[33:37]) // 4 bytes + request.RequestIV = buffer[1:17] // 16 bytes + request.RequestKey = buffer[17:33] // 16 bytes + request.ResponseHeader = buffer[33:37] // 4 bytes request.Command = buffer[37] port := binary.BigEndian.Uint16(buffer[38:40]) @@ -169,9 +169,9 @@ func (request *VMessRequest) ToBytes(idHash user.CounterHash, randomRangeInt64 u encryptionBegin := len(buffer) buffer = append(buffer, request.Version) - buffer = append(buffer, request.RequestIV[:]...) - buffer = append(buffer, request.RequestKey[:]...) - buffer = append(buffer, request.ResponseHeader[:]...) + buffer = append(buffer, request.RequestIV...) + buffer = append(buffer, request.RequestKey...) + buffer = append(buffer, request.ResponseHeader...) buffer = append(buffer, request.Command) buffer = append(buffer, request.Address.PortBytes()...) diff --git a/proxy/vmess/protocol/vmess_test.go b/proxy/vmess/protocol/vmess_test.go index b89604c51..c27d48836 100644 --- a/proxy/vmess/protocol/vmess_test.go +++ b/proxy/vmess/protocol/vmess_test.go @@ -26,20 +26,12 @@ func TestVMessSerialization(t *testing.T) { request.Version = byte(0x01) request.UserId = userId - _, err = rand.Read(request.RequestIV[:]) - if err != nil { - t.Fatal(err) - } - - _, err = rand.Read(request.RequestKey[:]) - if err != nil { - t.Fatal(err) - } - - _, err = rand.Read(request.ResponseHeader[:]) - if err != nil { - t.Fatal(err) - } + randBytes := make([]byte, 36) + _, err = rand.Read(randBytes) + assert.Error(err).IsNil() + request.RequestIV = randBytes[:16] + request.RequestKey = randBytes[16:32] + request.ResponseHeader = randBytes[32:] request.Command = byte(0x01) request.Address = v2net.DomainAddress("v2ray.com", 80) @@ -61,9 +53,9 @@ func TestVMessSerialization(t *testing.T) { assert.Byte(actualRequest.Version).Named("Version").Equals(byte(0x01)) assert.String(actualRequest.UserId.String).Named("UserId").Equals(request.UserId.String) - assert.Bytes(actualRequest.RequestIV[:]).Named("RequestIV").Equals(request.RequestIV[:]) - assert.Bytes(actualRequest.RequestKey[:]).Named("RequestKey").Equals(request.RequestKey[:]) - assert.Bytes(actualRequest.ResponseHeader[:]).Named("ResponseHeader").Equals(request.ResponseHeader[:]) + assert.Bytes(actualRequest.RequestIV).Named("RequestIV").Equals(request.RequestIV[:]) + assert.Bytes(actualRequest.RequestKey).Named("RequestKey").Equals(request.RequestKey[:]) + assert.Bytes(actualRequest.ResponseHeader).Named("ResponseHeader").Equals(request.ResponseHeader[:]) assert.Byte(actualRequest.Command).Named("Command").Equals(request.Command) assert.String(actualRequest.Address.String()).Named("Address").Equals(request.Address.String()) } @@ -77,9 +69,11 @@ func BenchmarkVMessRequestWriting(b *testing.B) { request.Version = byte(0x01) request.UserId = userId - rand.Read(request.RequestIV[:]) - rand.Read(request.RequestKey[:]) - rand.Read(request.ResponseHeader[:]) + randBytes := make([]byte, 36) + rand.Read(randBytes) + request.RequestIV = randBytes[:16] + request.RequestKey = randBytes[16:32] + request.ResponseHeader = randBytes[32:] request.Command = byte(0x01) request.Address = v2net.DomainAddress("v2ray.com", 80) diff --git a/proxy/vmess/vmessout.go b/proxy/vmess/vmessout.go index abac9c621..8ad1ab84a 100644 --- a/proxy/vmess/vmessout.go +++ b/proxy/vmess/vmessout.go @@ -80,9 +80,12 @@ func (handler *VMessOutboundHandler) Dispatch(firstPacket v2net.Packet, ray core Command: command, Address: firstPacket.Destination().Address(), } - rand.Read(request.RequestIV[:]) - rand.Read(request.RequestKey[:]) - rand.Read(request.ResponseHeader[:]) + + buffer := make([]byte, 36) // 16 + 16 + 4 + rand.Read(buffer) + request.RequestIV = buffer[:16] + request.RequestKey = buffer[16:32] + request.ResponseHeader = buffer[32:] go startCommunicate(request, vNextAddress, ray, firstPacket) return nil