1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-06-29 10:45:22 +00:00

Merge pull request #144 from Loyalsoldier/refine-code

Refine code
This commit is contained in:
RPRX 2020-08-30 15:59:23 +00:00 committed by GitHub
commit 9e0859ee49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 42 additions and 51 deletions

View File

@ -12,7 +12,6 @@ import (
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
v2net "v2ray.com/core/common/net"
) )
func Test_parseResponse(t *testing.T) { func Test_parseResponse(t *testing.T) {
@ -52,7 +51,7 @@ func Test_parseResponse(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"empty", {"empty",
&IPRecord{0, []v2net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess}, &IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
false, false,
}, },
{"error", {"error",
@ -60,12 +59,12 @@ func Test_parseResponse(t *testing.T) {
true, true,
}, },
{"a record", {"a record",
&IPRecord{1, []v2net.Address{v2net.ParseAddress("8.8.8.8"), v2net.ParseAddress("8.8.4.4")}, &IPRecord{1, []net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")},
time.Time{}, dnsmessage.RCodeSuccess}, time.Time{}, dnsmessage.RCodeSuccess},
false, false,
}, },
{"aaaa record", {"aaaa record",
&IPRecord{2, []v2net.Address{v2net.ParseAddress("2001::123:8888"), v2net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess}, &IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
false, false,
}, },
} }

View File

@ -200,7 +200,7 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
return h.getStatCouterConnection(conn), err return h.getStatCouterConnection(conn), err
} }
func (h *Handler) getStatCouterConnection(conn internet.Connection) (internet.Connection) { func (h *Handler) getStatCouterConnection(conn internet.Connection) internet.Connection {
if h.uplinkCounter != nil || h.downlinkCounter != nil { if h.uplinkCounter != nil || h.downlinkCounter != nil {
return &internet.StatCouterConnection{ return &internet.StatCouterConnection{
Connection: conn, Connection: conn,

View File

@ -57,7 +57,7 @@ func TestOutboundWithStatCounter(t *testing.T) {
serial.ToTypedMessage(&policy.Config{ serial.ToTypedMessage(&policy.Config{
System: &policy.SystemPolicy{ System: &policy.SystemPolicy{
Stats: &policy.SystemPolicy_Stats{ Stats: &policy.SystemPolicy_Stats{
OutboundUplink: true, OutboundUplink: true,
OutboundDownlink: true, OutboundDownlink: true,
}, },
}, },

View File

@ -83,7 +83,7 @@ func (p *IncrementalWorkerPicker) findAvailable() int {
return -1 return -1
} }
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, error, bool) { func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
p.access.Lock() p.access.Lock()
defer p.access.Unlock() defer p.access.Unlock()
@ -93,14 +93,14 @@ func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, error, bool) {
if n > 1 && idx != n-1 { if n > 1 && idx != n-1 {
p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1] p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
} }
return p.workers[idx], nil, false return p.workers[idx], false, nil
} }
p.cleanup() p.cleanup()
worker, err := p.Factory.Create() worker, err := p.Factory.Create()
if err != nil { if err != nil {
return nil, err, false return nil, false, err
} }
p.workers = append(p.workers, worker) p.workers = append(p.workers, worker)
@ -111,11 +111,11 @@ func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, error, bool) {
} }
} }
return worker, nil, true return worker, true, nil
} }
func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) { func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
worker, err, start := p.pickInternal() worker, start, err := p.pickInternal()
if start { if start {
common.Must(p.cleanupTask.Start()) common.Must(p.cleanupTask.Start())
} }

View File

@ -2,8 +2,8 @@ package sidh
import ( import (
"errors" "errors"
. "v2ray.com/core/external/github.com/cloudflare/sidh/internal/isogeny"
"io" "io"
. "v2ray.com/core/external/github.com/cloudflare/sidh/internal/isogeny"
) )
// I keep it bool in order to be able to apply logical NOT // I keep it bool in order to be able to apply logical NOT

View File

@ -3,8 +3,8 @@ package quic
import ( import (
"sync" "sync"
"v2ray.com/core/external/github.com/lucas-clemente/quic-go/internal/protocol"
"v2ray.com/core/common/bytespool" "v2ray.com/core/common/bytespool"
"v2ray.com/core/external/github.com/lucas-clemente/quic-go/internal/protocol"
) )
type packetBuffer struct { type packetBuffer struct {

View File

@ -22,8 +22,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"v2ray.com/core/external/github.com/cloudflare/sidh/sidh"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"v2ray.com/core/external/github.com/cloudflare/sidh/sidh"
) )
// numSessionTickets is the number of different session tickets the // numSessionTickets is the number of different session tickets the

View File

@ -83,7 +83,7 @@ func generateRandomBytes(random []byte, connType [4]byte) {
continue continue
} }
if 0x00000000 == (uint32(random[7])<<24)|(uint32(random[6])<<16)|(uint32(random[5])<<8)|uint32(random[4]) { if (uint32(random[7])<<24)|(uint32(random[6])<<16)|(uint32(random[5])<<8)|uint32(random[4]) == 0x00000000 {
continue continue
} }

View File

@ -38,7 +38,7 @@ var addrParser = protocol.NewAddressParser(
func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) { func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
account := user.Account.(*MemoryAccount) account := user.Account.(*MemoryAccount)
hashkdf := hmac.New(func()hash.Hash{return sha256.New()}, []byte("SSBSKDF")) hashkdf := hmac.New(func() hash.Hash { return sha256.New() }, []byte("SSBSKDF"))
hashkdf.Write(account.Key) hashkdf.Write(account.Key)
behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil)) behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))
@ -50,7 +50,6 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
DrainSize := BaseDrainSize + 16 + 38 + RandDrainRolled DrainSize := BaseDrainSize + 16 + 38 + RandDrainRolled
readSizeRemain := DrainSize readSizeRemain := DrainSize
buffer := buf.New() buffer := buf.New()
defer buffer.Release() defer buffer.Release()
@ -59,7 +58,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
if ivLen > 0 { if ivLen > 0 {
if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil { if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to read IV").Base(err) return nil, nil, newError("failed to read IV").Base(err)
} }
@ -69,7 +68,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
r, err := account.Cipher.NewDecryptionReader(account.Key, iv, reader) r, err := account.Cipher.NewDecryptionReader(account.Key, iv, reader)
if err != nil { if err != nil {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError() return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
} }
br := &buf.BufferedReader{Reader: r} br := &buf.BufferedReader{Reader: r}
@ -87,7 +86,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
addr, port, err := addrParser.ReadAddressPort(buffer, br) addr, port, err := addrParser.ReadAddressPort(buffer, br)
if err != nil { if err != nil {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("failed to read address").Base(err) return nil, nil, newError("failed to read address").Base(err)
} }
@ -101,13 +100,13 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled { if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("rejecting connection with OTA enabled, while server disables OTA") return nil, nil, newError("rejecting connection with OTA enabled, while server disables OTA")
} }
if !request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Enabled { if !request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Enabled {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("rejecting connection with OTA disabled, while server enables OTA") return nil, nil, newError("rejecting connection with OTA disabled, while server enables OTA")
} }
} }
@ -119,20 +118,20 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
_, err := buffer.ReadFullFrom(br, AuthSize) _, err := buffer.ReadFullFrom(br, AuthSize)
if err != nil { if err != nil {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("Failed to read OTA").Base(err) return nil, nil, newError("Failed to read OTA").Base(err)
} }
if !bytes.Equal(actualAuth, buffer.BytesFrom(-AuthSize)) { if !bytes.Equal(actualAuth, buffer.BytesFrom(-AuthSize)) {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("invalid OTA") return nil, nil, newError("invalid OTA")
} }
} }
if request.Address == nil { if request.Address == nil {
readSizeRemain -= int(buffer.Len()) readSizeRemain -= int(buffer.Len())
DrainConnN(reader,readSizeRemain) DrainConnN(reader, readSizeRemain)
return nil, nil, newError("invalid remote address.") return nil, nil, newError("invalid remote address.")
} }

View File

@ -8,6 +8,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"time" "time"
"v2ray.com/core/common" "v2ray.com/core/common"
) )
@ -21,8 +22,7 @@ func SealVMessAEADHeader(key [16]byte, data []byte) []byte {
aeadPayloadLengthSerializeBuffer := bytes.NewBuffer(nil) aeadPayloadLengthSerializeBuffer := bytes.NewBuffer(nil)
var headerPayloadDataLen uint16 headerPayloadDataLen := uint16(len(data))
headerPayloadDataLen = uint16(len(data))
common.Must(binary.Write(aeadPayloadLengthSerializeBuffer, binary.BigEndian, headerPayloadDataLen)) common.Must(binary.Write(aeadPayloadLengthSerializeBuffer, binary.BigEndian, headerPayloadDataLen))

View File

@ -3,9 +3,10 @@ package aead
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"io" "io"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestOpenVMessAEADHeader(t *testing.T) { func TestOpenVMessAEADHeader(t *testing.T) {

View File

@ -12,19 +12,18 @@ import (
"io/ioutil" "io/ioutil"
"sync" "sync"
"time" "time"
"v2ray.com/core/common/dice"
vmessaead "v2ray.com/core/proxy/vmess/aead"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/bitmask" "v2ray.com/core/common/bitmask"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/crypto" "v2ray.com/core/common/crypto"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/task" "v2ray.com/core/common/task"
"v2ray.com/core/proxy/vmess" "v2ray.com/core/proxy/vmess"
vmessaead "v2ray.com/core/proxy/vmess/aead"
) )
type sessionId struct { type sessionId struct {
@ -170,7 +169,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
var fixedSizeAuthID [16]byte var fixedSizeAuthID [16]byte
copy(fixedSizeAuthID[:], buffer.Bytes()) copy(fixedSizeAuthID[:], buffer.Bytes())
if foundAEAD == true { if foundAEAD {
vmessAccount = user.Account.(*vmess.MemoryAccount) vmessAccount = user.Account.(*vmess.MemoryAccount)
var fixedSizeCmdKey [16]byte var fixedSizeCmdKey [16]byte
copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey()) copy(fixedSizeCmdKey[:], vmessAccount.ID.CmdKey())
@ -405,8 +404,7 @@ func (s *ServerSession) EncodeResponseHeader(header *protocol.ResponseHeader, wr
aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil) aeadResponseHeaderLengthEncryptionBuffer := bytes.NewBuffer(nil)
var decryptedResponseHeaderLengthBinaryDeserializeBuffer uint16 decryptedResponseHeaderLengthBinaryDeserializeBuffer := uint16(aeadEncryptedHeaderBuffer.Len())
decryptedResponseHeaderLengthBinaryDeserializeBuffer = uint16(aeadEncryptedHeaderBuffer.Len())
common.Must(binary.Write(aeadResponseHeaderLengthEncryptionBuffer, binary.BigEndian, decryptedResponseHeaderLengthBinaryDeserializeBuffer)) common.Must(binary.Write(aeadResponseHeaderLengthEncryptionBuffer, binary.BigEndian, decryptedResponseHeaderLengthBinaryDeserializeBuffer))

View File

@ -11,13 +11,13 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"v2ray.com/core/common/dice"
"v2ray.com/core/proxy/vmess/aead"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/common/task" "v2ray.com/core/common/task"
"v2ray.com/core/proxy/vmess/aead"
) )
const ( const (
@ -141,8 +141,8 @@ func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
v.generateNewHashes(protocol.Timestamp(nowSec), uu) v.generateNewHashes(protocol.Timestamp(nowSec), uu)
account := uu.user.Account.(*MemoryAccount) account := uu.user.Account.(*MemoryAccount)
if v.behaviorFused == false { if !v.behaviorFused {
hashkdf := hmac.New(func()hash.Hash{return sha256.New()}, []byte("VMESSBSKDF")) hashkdf := hmac.New(func() hash.Hash { return sha256.New() }, []byte("VMESSBSKDF"))
hashkdf.Write(account.ID.Bytes()) hashkdf.Write(account.ID.Bytes())
v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil)) v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
} }

View File

@ -928,8 +928,8 @@ func TestVMessKCPLarge(t *testing.T) {
t.Error(err) t.Error(err)
} }
defer func(){ defer func() {
<-time.After(5*time.Second) <-time.After(5 * time.Second)
CloseAllServers(servers) CloseAllServers(servers)
}() }()
} }
@ -1178,8 +1178,8 @@ func TestVMessGCMMuxUDP(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
} }
defer func(){ defer func() {
<-time.After(5*time.Second) <-time.After(5 * time.Second)
CloseAllServers(servers) CloseAllServers(servers)
}() }()
} }

View File

@ -11,7 +11,7 @@ import (
"strings" "strings"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet"

View File

@ -126,7 +126,7 @@ func (h *HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
} }
} }
if hasThisUri == false { if !hasThisUri {
return nil, ErrHeaderMisMatch return nil, ErrHeaderMisMatch
} }

View File

@ -34,7 +34,6 @@ func TestReaderWriter(t *testing.T) {
t.Error("unknown error ", err) t.Error("unknown error ", err)
} }
_ = buffer _ = buffer
return
/* /*
if buffer.String() != "efg" { if buffer.String() != "efg" {
t.Error("buffer: ", buffer.String()) t.Error("buffer: ", buffer.String())
@ -256,7 +255,6 @@ func TestConnectionInvPath(t *testing.T) {
break break
} }
} }
return
} }
func TestConnectionInvReq(t *testing.T) { func TestConnectionInvReq(t *testing.T) {
@ -315,5 +313,4 @@ func TestConnectionInvReq(t *testing.T) {
if !strings.HasPrefix(string(l), "HTTP/1.1 400 Bad Request") { if !strings.HasPrefix(string(l), "HTTP/1.1 400 Bad Request") {
t.Error("Resp to non http conn", string(l)) t.Error("Resp to non http conn", string(l))
} }
return
} }

View File

@ -121,10 +121,7 @@ func (sw *SendingWindow) Flush(current uint32, rto uint32, maxInFlightSize uint3
segment.transmit++ segment.transmit++
sw.writer.Write(segment) sw.writer.Write(segment)
inFlightSize++ inFlightSize++
if inFlightSize >= maxInFlightSize { return inFlightSize < maxInFlightSize
return false
}
return true
}) })
if sw.onPacketLoss != nil && inFlightSize > 0 && sw.totalInFlightSize != 0 { if sw.onPacketLoss != nil && inFlightSize > 0 && sw.totalInFlightSize != 0 {