From 8971e699d97b46b41dc65ac585eda0b73a1e184c Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Tue, 19 Sep 2017 23:27:49 +0200 Subject: [PATCH] common.Must2 --- common/common.go | 6 ++++ proxy/shadowsocks/server.go | 8 ++--- proxy/vmess/account.go | 20 ++++++------ proxy/vmess/encoding/auth.go | 11 ++++--- proxy/vmess/encoding/client.go | 49 ++++++++++++++-------------- proxy/vmess/encoding/commands.go | 23 ++++++------- proxy/vmess/encoding/server.go | 9 ++--- proxy/vmess/inbound/inbound.go | 2 +- proxy/vmess/outbound/command.go | 8 ++--- proxy/vmess/vmess.go | 5 +-- transport/internet/kcp/connection.go | 2 +- transport/internet/kcp/crypt.go | 5 +-- transport/internet/kcp/output.go | 29 +++++++++++++--- 13 files changed, 103 insertions(+), 74 deletions(-) diff --git a/common/common.go b/common/common.go index 448784ae0..47d47dbde 100644 --- a/common/common.go +++ b/common/common.go @@ -10,3 +10,9 @@ func Must(err error) { panic(err) } } + +func Must2(v interface{}, err error) { + if err != nil { + panic(err) + } +} diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 4ee5733cc..62a86e4ff 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -70,7 +70,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet } } -func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error { +func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher dispatcher.Interface) error { udpServer := udp.NewDispatcher(dispatcher) reader := buf.NewReader(conn) @@ -81,7 +81,7 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection } for _, payload := range mpayload { - request, data, err := DecodeUDPPacket(v.user, payload) + request, data, err := DecodeUDPPacket(s.user, payload) if err != nil { if source, ok := proxy.SourceFromContext(ctx); ok { log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err)) @@ -91,13 +91,13 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection continue } - if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled { + if request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Disabled { log.Trace(newError("client payload enables OTA but server doesn't allow it")) payload.Release() continue } - if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled { + if !request.Option.Has(RequestOptionOneTimeAuth) && s.account.OneTimeAuth == Account_Enabled { log.Trace(newError("client payload disables OTA but server forces it")) payload.Release() continue diff --git a/proxy/vmess/account.go b/proxy/vmess/account.go index dac1d8889..4c00b4005 100644 --- a/proxy/vmess/account.go +++ b/proxy/vmess/account.go @@ -13,24 +13,24 @@ type InternalAccount struct { Security protocol.Security } -func (v *InternalAccount) AnyValidID() *protocol.ID { - if len(v.AlterIDs) == 0 { - return v.ID +func (a *InternalAccount) AnyValidID() *protocol.ID { + if len(a.AlterIDs) == 0 { + return a.ID } - return v.AlterIDs[dice.Roll(len(v.AlterIDs))] + return a.AlterIDs[dice.Roll(len(a.AlterIDs))] } -func (v *InternalAccount) Equals(account protocol.Account) bool { +func (a *InternalAccount) Equals(account protocol.Account) bool { vmessAccount, ok := account.(*InternalAccount) if !ok { return false } // TODO: handle AlterIds difference - return v.ID.Equals(vmessAccount.ID) + return a.ID.Equals(vmessAccount.ID) } -func (v *Account) AsAccount() (protocol.Account, error) { - id, err := uuid.ParseString(v.Id) +func (a *Account) AsAccount() (protocol.Account, error) { + id, err := uuid.ParseString(a.Id) if err != nil { log.Trace(newError("failed to parse ID").Base(err).AtError()) return nil, err @@ -38,7 +38,7 @@ func (v *Account) AsAccount() (protocol.Account, error) { protoID := protocol.NewID(id) return &InternalAccount{ ID: protoID, - AlterIDs: protocol.NewAlterIDs(protoID, uint16(v.AlterId)), - Security: v.SecuritySettings.AsSecurity(), + AlterIDs: protocol.NewAlterIDs(protoID, uint16(a.AlterId)), + Security: a.SecuritySettings.AsSecurity(), }, nil } diff --git a/proxy/vmess/encoding/auth.go b/proxy/vmess/encoding/auth.go index ecc267e4b..05477661c 100644 --- a/proxy/vmess/encoding/auth.go +++ b/proxy/vmess/encoding/auth.go @@ -4,15 +4,16 @@ import ( "crypto/md5" "hash/fnv" - "golang.org/x/crypto/sha3" - + "v2ray.com/core/common" "v2ray.com/core/common/serial" + + "golang.org/x/crypto/sha3" ) // Authenticate authenticates a byte array using Fnv hash. func Authenticate(b []byte) uint32 { fnv1hash := fnv.New32a() - fnv1hash.Write(b) + common.Must2(fnv1hash.Write(b)) return fnv1hash.Sum32() } @@ -81,7 +82,7 @@ type ShakeSizeParser struct { func NewShakeSizeParser(nonce []byte) *ShakeSizeParser { shake := sha3.NewShake128() - shake.Write(nonce) + common.Must2(shake.Write(nonce)) return &ShakeSizeParser{ shake: shake, } @@ -92,7 +93,7 @@ func (*ShakeSizeParser) SizeBytes() int { } func (s *ShakeSizeParser) next() uint16 { - s.shake.Read(s.buffer[:]) + common.Must2(s.shake.Read(s.buffer[:])) return serial.BytesToUint16(s.buffer[:]) } diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index 5d550697e..0ff3c2bb5 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -11,6 +11,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "v2ray.com/core/app/log" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/crypto" "v2ray.com/core/common/dice" @@ -43,7 +44,7 @@ type ClientSession struct { // NewClientSession creates a new ClientSession. func NewClientSession(idHash protocol.IDHash) *ClientSession { randomBytes := make([]byte, 33) // 16 + 16 + 1 - rand.Read(randomBytes) + common.Must2(rand.Read(randomBytes)) session := &ClientSession{} session.requestBodyKey = randomBytes[:16] @@ -58,22 +59,22 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { return session } -func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) { +func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() account, err := header.User.GetTypedAccount() if err != nil { log.Trace(newError("failed to get user account: ", err).AtError()) return } - idHash := v.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes()) - idHash.Write(timestamp.Bytes(nil)) - writer.Write(idHash.Sum(nil)) + idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes()) + common.Must2(idHash.Write(timestamp.Bytes(nil))) + common.Must2(writer.Write(idHash.Sum(nil))) buffer := make([]byte, 0, 512) buffer = append(buffer, Version) - buffer = append(buffer, v.requestBodyIV...) - buffer = append(buffer, v.requestBodyKey...) - buffer = append(buffer, v.responseHeader, byte(header.Option)) + buffer = append(buffer, c.requestBodyIV...) + buffer = append(buffer, c.requestBodyKey...) + buffer = append(buffer, c.responseHeader, byte(header.Option)) padingLen := dice.Roll(16) if header.Security.Is(protocol.SecurityType_LEGACY) { // Disable padding in legacy mode for a smooth transition. @@ -100,29 +101,27 @@ func (v *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ if padingLen > 0 { pading := make([]byte, padingLen) - rand.Read(pading) + common.Must2(rand.Read(pading)) buffer = append(buffer, pading...) } fnv1a := fnv.New32a() - fnv1a.Write(buffer) + common.Must2(fnv1a.Write(buffer)) buffer = fnv1a.Sum(buffer) timestampHash := md5.New() - timestampHash.Write(hashTimestamp(timestamp)) + common.Must2(timestampHash.Write(hashTimestamp(timestamp))) iv := timestampHash.Sum(nil) aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv) aesStream.XORKeyStream(buffer, buffer) - writer.Write(buffer) - - return + common.Must2(writer.Write(buffer)) } -func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { +func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { var sizeParser crypto.ChunkSizeEncoder = crypto.PlainChunkSizeParser{} if request.Option.Has(protocol.RequestOptionChunkMasking) { - sizeParser = NewShakeSizeParser(v.requestBodyIV) + sizeParser = NewShakeSizeParser(c.requestBodyIV) } if request.Security.Is(protocol.SecurityType_NONE) { if request.Option.Has(protocol.RequestOptionChunkStream) { @@ -141,7 +140,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } if request.Security.Is(protocol.SecurityType_LEGACY) { - aesStream := crypto.NewAesEncryptionStream(v.requestBodyKey, v.requestBodyIV) + aesStream := crypto.NewAesEncryptionStream(c.requestBodyKey, c.requestBodyIV) cryptionWriter := crypto.NewCryptionWriter(aesStream, writer) if request.Option.Has(protocol.RequestOptionChunkStream) { auth := &crypto.AEADAuthenticator{ @@ -156,13 +155,13 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } if request.Security.Is(protocol.SecurityType_AES128_GCM) { - block, _ := aes.NewCipher(v.requestBodyKey) + block, _ := aes.NewCipher(c.requestBodyKey) aead, _ := cipher.NewGCM(block) auth := &crypto.AEADAuthenticator{ AEAD: aead, NonceGenerator: &ChunkNonceGenerator{ - Nonce: append([]byte(nil), v.requestBodyIV...), + Nonce: append([]byte(nil), c.requestBodyIV...), Size: aead.NonceSize(), }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, @@ -171,12 +170,12 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write } if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) { - aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(v.requestBodyKey)) + aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.requestBodyKey)) auth := &crypto.AEADAuthenticator{ AEAD: aead, NonceGenerator: &ChunkNonceGenerator{ - Nonce: append([]byte(nil), v.requestBodyIV...), + Nonce: append([]byte(nil), c.requestBodyIV...), Size: aead.NonceSize(), }, AdditionalDataGenerator: crypto.NoOpBytesGenerator{}, @@ -299,8 +298,8 @@ type ChunkNonceGenerator struct { count uint16 } -func (v *ChunkNonceGenerator) Next() []byte { - serial.Uint16ToBytes(v.count, v.Nonce[:0]) - v.count++ - return v.Nonce[:v.Size] +func (g *ChunkNonceGenerator) Next() []byte { + serial.Uint16ToBytes(g.count, g.Nonce[:0]) + g.count++ + return g.Nonce[:g.Size] } diff --git a/proxy/vmess/encoding/commands.go b/proxy/vmess/encoding/commands.go index 84b24d445..eb25139a4 100644 --- a/proxy/vmess/encoding/commands.go +++ b/proxy/vmess/encoding/commands.go @@ -3,6 +3,7 @@ package encoding import ( "io" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -45,8 +46,8 @@ func MarshalCommand(command interface{}, writer io.Writer) error { return ErrCommandTooLarge } - writer.Write([]byte{cmdID, byte(len), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}) - writer.Write(buffer.Bytes()) + common.Must2(writer.Write([]byte{cmdID, byte(len), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)})) + common.Must2(writer.Write(buffer.Bytes())) return nil } @@ -78,7 +79,7 @@ type CommandFactory interface { type CommandSwitchAccountFactory struct { } -func (v *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error { +func (f *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error { cmd, ok := command.(*protocol.CommandSwitchAccount) if !ok { return ErrCommandTypeMismatch @@ -88,25 +89,25 @@ func (v *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Wri if cmd.Host != nil { hostStr = cmd.Host.String() } - writer.Write([]byte{byte(len(hostStr))}) + common.Must2(writer.Write([]byte{byte(len(hostStr))})) if len(hostStr) > 0 { - writer.Write([]byte(hostStr)) + common.Must2(writer.Write([]byte(hostStr))) } - writer.Write(cmd.Port.Bytes(nil)) + common.Must2(writer.Write(cmd.Port.Bytes(nil))) idBytes := cmd.ID.Bytes() - writer.Write(idBytes) + common.Must2(writer.Write(idBytes)) - writer.Write(serial.Uint16ToBytes(cmd.AlterIds, nil)) - writer.Write([]byte{byte(cmd.Level)}) + common.Must2(writer.Write(serial.Uint16ToBytes(cmd.AlterIds, nil))) + common.Must2(writer.Write([]byte{byte(cmd.Level)})) - writer.Write([]byte{cmd.ValidMin}) + common.Must2(writer.Write([]byte{cmd.ValidMin})) return nil } -func (v *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) { +func (f *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) { cmd := new(protocol.CommandSwitchAccount) if len(data) == 0 { return nil, newError("insufficient length.") diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 8f49a56d1..257462c59 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -11,6 +11,7 @@ import ( "time" "golang.org/x/crypto/chacha20poly1305" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/crypto" "v2ray.com/core/common/net" @@ -126,7 +127,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request } timestampHash := md5.New() - timestampHash.Write(hashTimestamp(timestamp)) + common.Must2(timestampHash.Write(hashTimestamp(timestamp))) iv := timestampHash.Sum(nil) account, err := user.GetTypedAccount() if err != nil { @@ -220,7 +221,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request } fnv1a := fnv.New32a() - fnv1a.Write(buffer[:bufferLen]) + common.Must2(fnv1a.Write(buffer[:bufferLen])) actualHash := fnv1a.Sum32() expectedHash := serial.BytesToUint32(buffer[bufferLen : bufferLen+4]) @@ -314,10 +315,10 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr encryptionWriter := crypto.NewCryptionWriter(aesStream, writer) s.responseWriter = encryptionWriter - encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)}) + common.Must2(encryptionWriter.Write([]byte{s.responseHeader, byte(header.Option)})) err := MarshalCommand(header.Command, encryptionWriter) if err != nil { - encryptionWriter.Write([]byte{0x00, 0x00}) + common.Must2(encryptionWriter.Write([]byte{0x00, 0x00})) } } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 75fc44b50..bbb03b423 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -132,7 +132,7 @@ func transferRequest(timer signal.ActivityTimer, session *encoding.ServerSession bodyReader := session.DecodeRequestBody(request, input) if err := buf.Copy(bodyReader, output, buf.UpdateActivity(timer)); err != nil { - return err + return newError("failed to transfer request").Base(err) } return nil } diff --git a/proxy/vmess/outbound/command.go b/proxy/vmess/outbound/command.go index 918980604..19316acd3 100644 --- a/proxy/vmess/outbound/command.go +++ b/proxy/vmess/outbound/command.go @@ -9,7 +9,7 @@ import ( "v2ray.com/core/proxy/vmess" ) -func (v *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { +func (h *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { account := &vmess.Account{ Id: cmd.ID.String(), AlterId: uint32(cmd.AlterIds), @@ -25,16 +25,16 @@ func (v *Handler) handleSwitchAccount(cmd *protocol.CommandSwitchAccount) { } dest := net.TCPDestination(cmd.Host, cmd.Port) until := time.Now().Add(time.Duration(cmd.ValidMin) * time.Minute) - v.serverList.AddServer(protocol.NewServerSpec(dest, protocol.BeforeTime(until), user)) + h.serverList.AddServer(protocol.NewServerSpec(dest, protocol.BeforeTime(until), user)) } -func (v *Handler) handleCommand(dest net.Destination, cmd protocol.ResponseCommand) { +func (h *Handler) handleCommand(dest net.Destination, cmd protocol.ResponseCommand) { switch typedCommand := cmd.(type) { case *protocol.CommandSwitchAccount: if typedCommand.Host == nil { typedCommand.Host = dest.Address } - v.handleSwitchAccount(typedCommand) + h.handleSwitchAccount(typedCommand) default: } } diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index 37d94e859..50928046e 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "v2ray.com/core/common" "v2ray.com/core/common/protocol" ) @@ -60,11 +61,11 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx in var hashValueRemoval [16]byte idHash := v.hasher(entry.id.Bytes()) for entry.lastSec <= nowSec { - idHash.Write(entry.lastSec.Bytes(nil)) + common.Must2(idHash.Write(entry.lastSec.Bytes(nil))) idHash.Sum(hashValue[:0]) idHash.Reset() - idHash.Write(entry.lastSecRemoval.Bytes(nil)) + common.Must2(idHash.Write(entry.lastSecRemoval.Bytes(nil))) idHash.Sum(hashValueRemoval[:0]) idHash.Reset() diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 5dd15705d..f20ab562e 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -214,7 +214,7 @@ func NewConnection(conv uint16, sysConn SystemConnection, config *Config) *Conne dataInput: make(chan bool, 1), dataOutput: make(chan bool, 1), Config: config, - output: NewSegmentWriter(sysConn), + output: NewRetryableWriter(NewSegmentWriter(sysConn)), mss: config.GetMTUValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead, roundTrip: &RoundTripInfo{ rto: 100, diff --git a/transport/internet/kcp/crypt.go b/transport/internet/kcp/crypt.go index 2925c6174..4bedbe31c 100644 --- a/transport/internet/kcp/crypt.go +++ b/transport/internet/kcp/crypt.go @@ -4,6 +4,7 @@ import ( "crypto/cipher" "hash/fnv" + "v2ray.com/core/common" "v2ray.com/core/common/serial" ) @@ -32,7 +33,7 @@ func (a *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte { dst = append(dst, plain...) fnvHash := fnv.New32a() - fnvHash.Write(dst[4:]) + common.Must2(fnvHash.Write(dst[4:])) fnvHash.Sum(dst[:0]) len := len(dst) @@ -61,7 +62,7 @@ func (a *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte } fnvHash := fnv.New32a() - fnvHash.Write(dst[4:]) + common.Must2(fnvHash.Write(dst[4:])) if serial.BytesToUint32(dst[:4]) != fnvHash.Sum32() { return nil, newError("invalid auth") } diff --git a/transport/internet/kcp/output.go b/transport/internet/kcp/output.go index b2a2deef6..9e00215dd 100644 --- a/transport/internet/kcp/output.go +++ b/transport/internet/kcp/output.go @@ -4,6 +4,9 @@ import ( "io" "sync" + "v2ray.com/core/common/retry" + + "v2ray.com/core/common" "v2ray.com/core/common/buf" ) @@ -24,11 +27,27 @@ func NewSegmentWriter(writer io.Writer) SegmentWriter { } } -func (v *SimpleSegmentWriter) Write(seg Segment) error { - v.Lock() - defer v.Unlock() +func (w *SimpleSegmentWriter) Write(seg Segment) error { + w.Lock() + defer w.Unlock() - v.buffer.Reset(seg.Bytes()) - _, err := v.writer.Write(v.buffer.Bytes()) + common.Must(w.buffer.Reset(seg.Bytes())) + _, err := w.writer.Write(w.buffer.Bytes()) return err } + +type RetryableWriter struct { + writer SegmentWriter +} + +func NewRetryableWriter(writer SegmentWriter) SegmentWriter { + return &RetryableWriter{ + writer: writer, + } +} + +func (w *RetryableWriter) Write(seg Segment) error { + return retry.Timed(5, 100).On(func() error { + return w.writer.Write(seg) + }) +}