1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2025-01-21 16:56:27 -05:00

refine kcp header and security

This commit is contained in:
Darien Raymond 2016-12-08 16:27:41 +01:00
parent 0ad629ca31
commit 0917866f38
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
39 changed files with 530 additions and 393 deletions

8
all.go
View File

@ -21,8 +21,8 @@ import (
_ "v2ray.com/core/transport/internet/udp"
_ "v2ray.com/core/transport/internet/ws"
_ "v2ray.com/core/transport/internet/authenticators/http"
_ "v2ray.com/core/transport/internet/authenticators/noop"
_ "v2ray.com/core/transport/internet/authenticators/srtp"
_ "v2ray.com/core/transport/internet/authenticators/utp"
_ "v2ray.com/core/transport/internet/headers/http"
_ "v2ray.com/core/transport/internet/headers/noop"
_ "v2ray.com/core/transport/internet/headers/srtp"
_ "v2ray.com/core/transport/internet/headers/utp"
)

View File

@ -3,10 +3,10 @@ package conf
import (
"v2ray.com/core/common/errors"
"v2ray.com/core/common/loader"
"v2ray.com/core/transport/internet/authenticators/http"
"v2ray.com/core/transport/internet/authenticators/noop"
"v2ray.com/core/transport/internet/authenticators/srtp"
"v2ray.com/core/transport/internet/authenticators/utp"
"v2ray.com/core/transport/internet/headers/http"
"v2ray.com/core/transport/internet/headers/noop"
"v2ray.com/core/transport/internet/headers/srtp"
"v2ray.com/core/transport/internet/headers/utp"
)
type NoOpAuthenticator struct{}

View File

@ -1,28 +0,0 @@
package internet_test
import (
"testing"
"v2ray.com/core/common/loader"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/authenticators/noop"
"v2ray.com/core/transport/internet/authenticators/srtp"
"v2ray.com/core/transport/internet/authenticators/utp"
)
func TestAllAuthenticatorLoadable(t *testing.T) {
assert := assert.On(t)
noopAuth, err := CreateAuthenticator(loader.GetType(new(noop.Config)), nil)
assert.Error(err).IsNil()
assert.Int(noopAuth.Overhead()).Equals(0)
srtp, err := CreateAuthenticator(loader.GetType(new(srtp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(srtp.Overhead()).Equals(4)
utp, err := CreateAuthenticator(loader.GetType(new(utp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(utp.Overhead()).Equals(4)
}

View File

@ -1,8 +0,0 @@
syntax = "proto3";
package v2ray.core.transport.internet.authenticators.noop;
option go_package = "noop";
option java_package = "com.v2ray.core.transport.internet.authenticators.noop";
option java_outer_classname = "ConfigProto";
message Config {}

View File

@ -1,46 +0,0 @@
package noop
import (
"net"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/loader"
"v2ray.com/core/transport/internet"
)
type NoOpAuthenticator struct{}
func (v NoOpAuthenticator) Overhead() int {
return 0
}
func (v NoOpAuthenticator) Open(payload *alloc.Buffer) bool {
return true
}
func (v NoOpAuthenticator) Seal(payload *alloc.Buffer) {}
type NoOpAuthenticatorFactory struct{}
func (v NoOpAuthenticatorFactory) Create(config interface{}) internet.Authenticator {
return NoOpAuthenticator{}
}
type NoOpConnectionAuthenticator struct{}
func (NoOpConnectionAuthenticator) Client(conn net.Conn) net.Conn {
return conn
}
func (NoOpConnectionAuthenticator) Server(conn net.Conn) net.Conn {
return conn
}
type NoOpConnectionAuthenticatorFactory struct{}
func (NoOpConnectionAuthenticatorFactory) Create(config interface{}) internet.ConnectionAuthenticator {
return NoOpConnectionAuthenticator{}
}
func init() {
internet.RegisterAuthenticator(loader.GetType(new(Config)), NoOpAuthenticatorFactory{})
internet.RegisterConnectionAuthenticator(loader.GetType(new(Config)), NoOpConnectionAuthenticatorFactory{})
}

View File

@ -1,44 +0,0 @@
package srtp
import (
"math/rand"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type SRTP struct {
header uint16
number uint16
}
func (v *SRTP) Overhead() int {
return 4
}
func (v *SRTP) Open(payload *alloc.Buffer) bool {
payload.SliceFrom(v.Overhead())
return true
}
func (v *SRTP) Seal(payload *alloc.Buffer) {
v.number++
payload.PrependFunc(2, serial.WriteUint16(v.number))
payload.PrependFunc(2, serial.WriteUint16(v.header))
}
type SRTPFactory struct {
}
func (v SRTPFactory) Create(rawSettings interface{}) internet.Authenticator {
return &SRTP{
header: 0xB5E8,
number: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterAuthenticator(loader.GetType(new(Config)), SRTPFactory{})
}

View File

@ -1,10 +0,0 @@
syntax = "proto3";
package v2ray.core.transport.internet.authenticators.utp;
option go_package = "utp";
option java_package = "com.v2ray.core.transport.internet.authenticators.utp";
option java_outer_classname = "ConfigProto";
message Config {
uint32 version = 1;
}

View File

@ -1,44 +0,0 @@
package utp
import (
"math/rand"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type UTP struct {
header byte
extension byte
connectionId uint16
}
func (v *UTP) Overhead() int {
return 4
}
func (v *UTP) Open(payload *alloc.Buffer) bool {
payload.SliceFrom(v.Overhead())
return true
}
func (v *UTP) Seal(payload *alloc.Buffer) {
payload.PrependFunc(2, serial.WriteUint16(v.connectionId))
payload.PrependBytes(v.header, v.extension)
}
type UTPFactory struct{}
func (v UTPFactory) Create(rawSettings interface{}) internet.Authenticator {
return &UTP{
header: 1,
extension: 0,
connectionId: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterAuthenticator(loader.GetType(new(Config)), UTPFactory{})
}

View File

@ -9,20 +9,6 @@ import (
. "v2ray.com/core/transport/internet"
)
func TestDialDomain(t *testing.T) {
assert := assert.On(t)
server := &tcp.Server{}
dest, err := server.Start()
assert.Error(err).IsNil()
defer server.Close()
conn, err := DialToDest(nil, v2net.TCPDestination(v2net.DomainAddress("local.v2ray.com"), dest.Port))
assert.Error(err).IsNil()
assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String())
conn.Close()
}
func TestDialWithLocalAddr(t *testing.T) {
assert := assert.On(t)

View File

@ -0,0 +1,32 @@
package internet
import "v2ray.com/core/common"
type PacketHeader interface {
Size() int
Write([]byte) int
}
type PacketHeaderFactory interface {
Create(interface{}) PacketHeader
}
var (
headerCache = make(map[string]PacketHeaderFactory)
)
func RegisterPacketHeader(name string, factory PacketHeaderFactory) error {
if _, found := headerCache[name]; found {
return common.ErrDuplicatedName
}
headerCache[name] = factory
return nil
}
func CreatePacketHeader(name string, config interface{}) (PacketHeader, error) {
factory, found := headerCache[name]
if !found {
return nil, common.ErrObjectNotFound
}
return factory.Create(config), nil
}

View File

@ -0,0 +1,28 @@
package internet_test
import (
"testing"
"v2ray.com/core/common/loader"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/headers/noop"
"v2ray.com/core/transport/internet/headers/srtp"
"v2ray.com/core/transport/internet/headers/utp"
)
func TestAllHeadersLoadable(t *testing.T) {
assert := assert.On(t)
noopAuth, err := CreatePacketHeader(loader.GetType(new(noop.Config)), nil)
assert.Error(err).IsNil()
assert.Int(noopAuth.Size()).Equals(0)
srtp, err := CreatePacketHeader(loader.GetType(new(srtp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(srtp.Size()).Equals(4)
utp, err := CreatePacketHeader(loader.GetType(new(utp.Config)), nil)
assert.Error(err).IsNil()
assert.Int(utp.Size()).Equals(4)
}

View File

@ -1,8 +1,8 @@
syntax = "proto3";
package v2ray.core.transport.internet.authenticators.http;
package v2ray.core.transport.internet.headers.http;
option go_package = "http";
option java_package = "com.v2ray.core.transport.internet.authenticators.http";
option java_package = "com.v2ray.core.transport.internet.headers.http";
option java_outer_classname = "ConfigProto";
message Header {

View File

@ -6,7 +6,7 @@ import (
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/authenticators/http"
. "v2ray.com/core/transport/internet/headers/http"
)
func TestReaderWriter(t *testing.T) {

View File

@ -0,0 +1,8 @@
syntax = "proto3";
package v2ray.core.transport.internet.headers.noop;
option go_package = "noop";
option java_package = "com.v2ray.core.transport.internet.headers.noop";
option java_outer_classname = "ConfigProto";
message Config {}

View File

@ -0,0 +1,44 @@
package noop
import (
"net"
"v2ray.com/core/common/loader"
"v2ray.com/core/transport/internet"
)
type NoOpHeader struct{}
func (v NoOpHeader) Size() int {
return 0
}
func (v NoOpHeader) Write([]byte) int {
return 0
}
type NoOpHeaderFactory struct{}
func (v NoOpHeaderFactory) Create(config interface{}) internet.PacketHeader {
return NoOpHeader{}
}
type NoOpConnectionHeader struct{}
func (NoOpConnectionHeader) Client(conn net.Conn) net.Conn {
return conn
}
func (NoOpConnectionHeader) Server(conn net.Conn) net.Conn {
return conn
}
type NoOpConnectionHeaderFactory struct{}
func (NoOpConnectionHeaderFactory) Create(config interface{}) internet.ConnectionAuthenticator {
return NoOpConnectionHeader{}
}
func init() {
internet.RegisterPacketHeader(loader.GetType(new(Config)), NoOpHeaderFactory{})
internet.RegisterConnectionAuthenticator(loader.GetType(new(Config)), NoOpConnectionHeaderFactory{})
}

View File

@ -1,8 +1,8 @@
syntax = "proto3";
package v2ray.core.transport.internet.authenticators.srtp;
package v2ray.core.transport.internet.headers.srtp;
option go_package = "srtp";
option java_package = "com.v2ray.core.transport.internet.authenticators.srtp";
option java_package = "com.v2ray.core.transport.internet.headers.srtp";
option java_outer_classname = "ConfigProto";
message Config {

View File

@ -0,0 +1,39 @@
package srtp
import (
"math/rand"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type SRTP struct {
header uint16
number uint16
}
func (v *SRTP) Size() int {
return 4
}
func (v *SRTP) Write(b []byte) int {
v.number++
b = serial.Uint16ToBytes(v.number, b[:0])
b = serial.Uint16ToBytes(v.number, b)
return 4
}
type SRTPFactory struct {
}
func (v SRTPFactory) Create(rawSettings interface{}) internet.PacketHeader {
return &SRTP{
header: 0xB5E8,
number: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterPacketHeader(loader.GetType(new(Config)), SRTPFactory{})
}

View File

@ -5,19 +5,18 @@ import (
"v2ray.com/core/common/alloc"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/authenticators/srtp"
. "v2ray.com/core/transport/internet/headers/srtp"
)
func TestSRTPOpenSeal(t *testing.T) {
func TestSRTPWrite(t *testing.T) {
assert := assert.On(t)
content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}
srtp := SRTP{}
payload := alloc.NewLocalBuffer(2048)
payload.AppendFunc(srtp.Write)
payload.Append(content)
srtp := SRTP{}
srtp.Seal(payload)
assert.Int(payload.Len()).GreaterThan(len(content))
assert.Bool(srtp.Open(payload)).IsTrue()
assert.Bytes(content).Equals(payload.Bytes())
assert.Int(payload.Len()).Equals(len(content) + srtp.Size())
}

View File

@ -0,0 +1,10 @@
syntax = "proto3";
package v2ray.core.transport.internet.headers.utp;
option go_package = "utp";
option java_package = "com.v2ray.core.transport.internet.headers.utp";
option java_outer_classname = "ConfigProto";
message Config {
uint32 version = 1;
}

View File

@ -0,0 +1,39 @@
package utp
import (
"math/rand"
"v2ray.com/core/common/loader"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
type UTP struct {
header byte
extension byte
connectionId uint16
}
func (v *UTP) Size() int {
return 4
}
func (v *UTP) Write(b []byte) int {
b = serial.Uint16ToBytes(v.connectionId, b[:0])
b = append(b, v.header, v.extension)
return 4
}
type UTPFactory struct{}
func (v UTPFactory) Create(rawSettings interface{}) internet.PacketHeader {
return &UTP{
header: 1,
extension: 0,
connectionId: uint16(rand.Intn(65536)),
}
}
func init() {
internet.RegisterPacketHeader(loader.GetType(new(Config)), UTPFactory{})
}

View File

@ -5,19 +5,18 @@ import (
"v2ray.com/core/common/alloc"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/authenticators/utp"
. "v2ray.com/core/transport/internet/headers/utp"
)
func TestUTPOpenSeal(t *testing.T) {
func TestUTPWrite(t *testing.T) {
assert := assert.On(t)
content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}
utp := UTP{}
payload := alloc.NewLocalBuffer(2048)
payload.AppendFunc(utp.Write)
payload.Append(content)
utp := UTP{}
utp.Seal(payload)
assert.Int(payload.Len()).GreaterThan(len(content))
assert.Bool(utp.Open(payload)).IsTrue()
assert.Bytes(content).Equals(payload.Bytes())
assert.Int(payload.Len()).Equals(len(content) + utp.Size())
}

View File

@ -1,6 +1,8 @@
package kcp
import (
"crypto/cipher"
v2net "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet"
)
@ -47,21 +49,20 @@ func (v *ReadBuffer) GetSize() uint32 {
return v.Size
}
func (v *Config) GetAuthenticator() (internet.Authenticator, error) {
auth := NewSimpleAuthenticator()
func (v *Config) GetSecurity() (cipher.AEAD, error) {
return NewSimpleAuthenticator(), nil
}
func (v *Config) GetPackerHeader() (internet.PacketHeader, error) {
if v.HeaderConfig != nil {
rawConfig, err := v.HeaderConfig.GetInstance()
if err != nil {
return nil, err
}
header, err := internet.CreateAuthenticator(v.HeaderConfig.Type, rawConfig)
if err != nil {
return nil, err
}
auth = internet.NewAuthenticatorChain(header, auth)
return internet.CreatePacketHeader(v.HeaderConfig.Type, rawConfig)
}
return auth, nil
return nil, nil
}
func (v *Config) GetSendingInFlightSize() uint32 {

View File

@ -6,6 +6,7 @@ import (
"sync"
"sync/atomic"
"time"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/log"
"v2ray.com/core/common/predicate"
@ -161,7 +162,8 @@ func (v *Updater) Run() {
type SystemConnection interface {
net.Conn
Id() internal.ConnectionId
Reset(internet.Authenticator, func([]byte))
Reset(func([]Segment))
Overhead() int
}
// Connection is a KCP connection over UDP.
@ -197,32 +199,25 @@ type Connection struct {
}
// NewConnection create a new KCP connection between local and remote.
func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.ConnectionRecyler, block internet.Authenticator, config *Config) *Connection {
func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.ConnectionRecyler, config *Config) *Connection {
log.Info("KCP|Connection: creating connection ", conv)
authWriter := &AuthenticationWriter{
Authenticator: block,
Writer: sysConn,
Config: config,
}
conn := &Connection{
conv: conv,
conn: sysConn,
connRecycler: recycler,
block: block,
since: nowMillisec(),
dataInputCond: sync.NewCond(new(sync.Mutex)),
dataOutputCond: sync.NewCond(new(sync.Mutex)),
Config: config,
output: NewSegmentWriter(authWriter),
mss: authWriter.Mtu() - DataSegmentOverhead,
output: NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())),
mss: config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead,
roundTrip: &RoundTripInfo{
rto: 100,
minRtt: config.Tti.GetValue(),
},
}
sysConn.Reset(block, conn.Input)
sysConn.Reset(conn.Input)
conn.receivingWorker = NewReceivingWorker(conn)
conn.sendingWorker = NewSendingWorker(conn)
@ -480,16 +475,11 @@ func (v *Connection) OnPeerClosed() {
}
// Input when you received a low level packet (eg. UDP packet), call it
func (v *Connection) Input(data []byte) {
func (v *Connection) Input(segments []Segment) {
current := v.Elapsed()
atomic.StoreUint32(&v.lastIncomingTime, current)
var seg Segment
for {
seg, data = ReadSegment(data)
if seg == nil {
break
}
for _, seg := range segments {
if seg.Conversation() != v.conv {
return
}
@ -507,7 +497,7 @@ func (v *Connection) Input(data []byte) {
v.dataUpdater.WakeUp()
case *CmdOnlySegment:
v.HandleOption(seg.Option)
if seg.Command == CommandTerminate {
if seg.Command() == CommandTerminate {
state := v.State()
if state == StateActive ||
state == StatePeerClosed {
@ -577,7 +567,7 @@ func (v *Connection) State() State {
func (v *Connection) Ping(current uint32, cmd Command) {
seg := NewCmdOnlySegment()
seg.Conv = v.conv
seg.Command = cmd
seg.Cmd = cmd
seg.ReceivinNext = v.receivingWorker.nextNumber
seg.SendingNext = v.sendingWorker.firstUnacknowledged
seg.PeerRTO = v.roundTrip.Timeout()

View File

@ -4,14 +4,18 @@ import (
"net"
"testing"
"time"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/internal"
. "v2ray.com/core/transport/internet/kcp"
)
type NoOpConn struct{}
func (o *NoOpConn) Overhead() int {
return 0
}
func (o *NoOpConn) Write(b []byte) (int, error) {
return len(b), nil
}
@ -48,7 +52,7 @@ func (o *NoOpConn) Id() internal.ConnectionId {
return internal.ConnectionId{}
}
func (o *NoOpConn) Reset(auth internet.Authenticator, input func([]byte)) {}
func (o *NoOpConn) Reset(input func([]Segment)) {}
type NoOpRecycler struct{}
@ -57,7 +61,7 @@ func (o *NoOpRecycler) Put(internal.ConnectionId, net.Conn) {}
func TestConnectionReadTimeout(t *testing.T) {
assert := assert.On(t)
conn := NewConnection(1, &NoOpConn{}, &NoOpRecycler{}, NewSimpleAuthenticator(), &Config{})
conn := NewConnection(1, &NoOpConn{}, &NoOpRecycler{}, &Config{})
conn.SetReadDeadline(time.Now().Add(time.Second))
b := make([]byte, 1024)

View File

@ -1,63 +1,74 @@
package kcp
import (
"crypto/cipher"
"errors"
"hash/fnv"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
var (
errInvalidAuth = errors.New("Invalid auth.")
)
type SimpleAuthenticator struct{}
func NewSimpleAuthenticator() internet.Authenticator {
func NewSimpleAuthenticator() cipher.AEAD {
return &SimpleAuthenticator{}
}
func (v *SimpleAuthenticator) NonceSize() int {
return 0
}
func (v *SimpleAuthenticator) Overhead() int {
return 6
}
func (v *SimpleAuthenticator) Seal(buffer *alloc.Buffer) {
buffer.PrependFunc(2, serial.WriteUint16(uint16(buffer.Len())))
fnvHash := fnv.New32a()
fnvHash.Write(buffer.Bytes())
buffer.PrependFunc(4, serial.WriteHash(fnvHash))
func (v *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte {
dst = append(dst, 0, 0, 0, 0)
dst = serial.Uint16ToBytes(uint16(len(plain)), dst)
dst = append(dst, plain...)
len := buffer.Len()
fnvHash := fnv.New32a()
fnvHash.Write(dst[4:])
fnvHash.Sum(dst[:0])
len := len(dst)
xtra := 4 - len%4
if xtra != 0 {
buffer.Slice(0, len+xtra)
if xtra != 4 {
dst = append(dst, make([]byte, xtra)...)
}
xorfwd(buffer.Bytes())
if xtra != 0 {
buffer.Slice(0, len)
xorfwd(dst)
if xtra != 4 {
dst = dst[:len]
}
return dst
}
func (v *SimpleAuthenticator) Open(buffer *alloc.Buffer) bool {
len := buffer.Len()
xtra := 4 - len%4
if xtra != 0 {
buffer.Slice(0, len+xtra)
func (v *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) {
dst = append(dst, cipherText...)
dstLen := len(dst)
xtra := 4 - dstLen%4
if xtra != 4 {
dst = append(dst, make([]byte, xtra)...)
}
xorbkd(buffer.Bytes())
if xtra != 0 {
buffer.Slice(0, len)
xorbkd(dst)
if xtra != 4 {
dst = dst[:dstLen]
}
fnvHash := fnv.New32a()
fnvHash.Write(buffer.BytesFrom(4))
if serial.BytesToUint32(buffer.BytesTo(4)) != fnvHash.Sum32() {
return false
fnvHash.Write(dst[4:])
if serial.BytesToUint32(dst[:4]) != fnvHash.Sum32() {
return nil, errInvalidAuth
}
length := serial.BytesToUint16(buffer.BytesRange(4, 6))
if buffer.Len()-6 != int(length) {
return false
length := serial.BytesToUint16(dst[4:6])
if len(dst)-6 != int(length) {
return nil, errInvalidAuth
}
buffer.SliceFrom(6)
return true
return dst[6:], nil
}

View File

@ -1,10 +1,8 @@
package kcp_test
import (
"crypto/rand"
"testing"
"v2ray.com/core/common/alloc"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/transport/internet/kcp"
)
@ -12,38 +10,27 @@ import (
func TestSimpleAuthenticator(t *testing.T) {
assert := assert.On(t)
buffer := alloc.NewLocalBuffer(512)
buffer.AppendBytes('a', 'b', 'c', 'd', 'e', 'f', 'g')
cache := make([]byte, 512)
payload := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'}
auth := NewSimpleAuthenticator()
auth.Seal(buffer)
assert.Bool(auth.Open(buffer)).IsTrue()
assert.Bytes(buffer.Bytes()).Equals([]byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'})
b := auth.Seal(cache[:0], nil, payload, nil)
c, err := auth.Open(cache[:0], nil, b, nil)
assert.Error(err).IsNil()
assert.Bytes(c).Equals(payload)
}
func TestSimpleAuthenticator2(t *testing.T) {
assert := assert.On(t)
buffer := alloc.NewLocalBuffer(512)
buffer.AppendBytes('1', '2')
cache := make([]byte, 512)
payload := []byte{'a', 'b'}
auth := NewSimpleAuthenticator()
auth.Seal(buffer)
assert.Bool(auth.Open(buffer)).IsTrue()
assert.Bytes(buffer.Bytes()).Equals([]byte{'1', '2'})
}
func BenchmarkSimpleAuthenticator(b *testing.B) {
buffer := alloc.NewLocalBuffer(2048)
buffer.FillFullFrom(rand.Reader, 1024)
auth := NewSimpleAuthenticator()
b.SetBytes(int64(buffer.Len()))
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth.Seal(buffer)
auth.Open(buffer)
}
b := auth.Seal(cache[:0], nil, payload, nil)
c, err := auth.Open(cache[:0], nil, b, nil)
assert.Error(err).IsNil()
assert.Bytes(c).Equals(payload)
}

View File

@ -5,8 +5,12 @@ import (
"net"
"sync"
"sync/atomic"
"crypto/cipher"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net"
"v2ray.com/core/transport/internet"
@ -20,11 +24,32 @@ var (
)
type ClientConnection struct {
sync.Mutex
sync.RWMutex
net.Conn
id internal.ConnectionId
input func([]byte)
auth internet.Authenticator
id internal.ConnectionId
input func([]Segment)
reader PacketReader
writer PacketWriter
}
func (o *ClientConnection) Overhead() int {
o.RLock()
defer o.RUnlock()
if o.writer == nil {
return 0
}
return o.writer.Overhead()
}
func (o *ClientConnection) Write(b []byte) (int, error) {
o.RLock()
defer o.RUnlock()
if o.writer == nil {
return len(b), nil
}
return o.writer.Write(b)
}
func (o *ClientConnection) Read([]byte) (int, error) {
@ -39,10 +64,26 @@ func (o *ClientConnection) Close() error {
return o.Conn.Close()
}
func (o *ClientConnection) Reset(auth internet.Authenticator, inputCallback func([]byte)) {
func (o *ClientConnection) Reset(inputCallback func([]Segment)) {
o.Lock()
o.input = inputCallback
o.auth = auth
o.Unlock()
}
func (o *ClientConnection) ResetSecurity(header internet.PacketHeader, security cipher.AEAD) {
o.Lock()
if o.reader == nil {
o.reader = new(KCPPacketReader)
}
o.reader.(*KCPPacketReader).Header = header
o.reader.(*KCPPacketReader).Security = security
if o.writer == nil {
o.writer = new(KCPPacketWriter)
}
o.writer.(*KCPPacketWriter).Header = header
o.writer.(*KCPPacketWriter).Security = security
o.writer.(*KCPPacketWriter).Writer = o.Conn
o.Unlock()
}
@ -57,12 +98,14 @@ func (o *ClientConnection) Run() {
payload.Release()
return
}
o.Lock()
if o.input != nil && o.auth.Open(payload) {
o.input(payload.Bytes())
o.RLock()
if o.input != nil {
segments := o.reader.Read(payload.Bytes())
if len(segments) > 0 {
o.input(segments)
}
}
o.Unlock()
payload.Reset()
o.RUnlock()
}
}
@ -93,13 +136,18 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO
}
kcpSettings := networkSettings.(*Config)
cpip, err := kcpSettings.GetAuthenticator()
clientConn := conn.(*ClientConnection)
header, err := kcpSettings.GetPackerHeader()
if err != nil {
log.Error("KCP|Dialer: Failed to create authenticator: ", err)
return nil, err
return nil, errors.Base(err).Message("KCP|Dialer: Failed to create packet header.")
}
security, err := kcpSettings.GetSecurity()
if err != nil {
return nil, errors.Base(err).Message("KCP|Dialer: Failed to create security.")
}
clientConn.ResetSecurity(header, security)
conv := uint16(atomic.AddUint32(&globalConv, 1))
session := NewConnection(conv, conn.(*ClientConnection), globalPool, cpip, kcpSettings)
session := NewConnection(conv, clientConn, globalPool, kcpSettings)
var iConn internet.Connection
iConn = session

View File

@ -0,0 +1,92 @@
package kcp
import (
"crypto/cipher"
"crypto/rand"
"io"
"v2ray.com/core/transport/internet"
)
type PacketReader interface {
Read([]byte) []Segment
}
type PacketWriter interface {
Overhead() int
io.Writer
}
type KCPPacketReader struct {
Security cipher.AEAD
Header internet.PacketHeader
}
func (v *KCPPacketReader) Read(b []byte) []Segment {
if v.Header != nil {
b = b[v.Header.Size():]
}
if v.Security != nil {
nonceSize := v.Security.NonceSize()
out, err := v.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil)
if err != nil {
return nil
}
b = out
}
var result []Segment
for len(b) > 0 {
seg, x := ReadSegment(b)
if seg == nil {
break
}
result = append(result, seg)
b = x
}
return result
}
type KCPPacketWriter struct {
Header internet.PacketHeader
Security cipher.AEAD
Writer io.Writer
buffer [32 * 1024]byte
}
func (v *KCPPacketWriter) Overhead() int {
overhead := 0
if v.Header != nil {
overhead += v.Header.Size()
}
if v.Security != nil {
overhead += v.Security.Overhead()
}
return overhead
}
func (v *KCPPacketWriter) Write(b []byte) (int, error) {
x := v.buffer[:]
size := 0
if v.Header != nil {
nBytes := v.Header.Write(x)
size += nBytes
x = x[nBytes:]
}
if v.Security != nil {
nonceSize := v.Security.NonceSize()
var nonce []byte
if nonceSize > 0 {
nonce = x[:nonceSize]
rand.Read(nonce)
x = x[nonceSize:]
}
x = v.Security.Seal(x[:0], nonce, b, nil)
size += nonceSize + len(x)
} else {
size += copy(x, b)
}
_, err := v.Writer.Write(v.buffer[:size])
return len(b), err
}

View File

@ -0,0 +1 @@
package kcp_test

View File

@ -2,14 +2,17 @@ package kcp
import (
"crypto/tls"
"io"
"net"
"sync"
"time"
"crypto/cipher"
"v2ray.com/core/common/alloc"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/log"
v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy"
"v2ray.com/core/transport/internet"
"v2ray.com/core/transport/internet/internal"
@ -25,11 +28,14 @@ type ConnectionId struct {
type ServerConnection struct {
id internal.ConnectionId
writer *Writer
local net.Addr
remote net.Addr
auth internet.Authenticator
input func([]byte)
writer PacketWriter
closer io.Closer
}
func (o *ServerConnection) Overhead() int {
return o.writer.Overhead()
}
func (o *ServerConnection) Read([]byte) (int, error) {
@ -41,20 +47,10 @@ func (o *ServerConnection) Write(b []byte) (int, error) {
}
func (o *ServerConnection) Close() error {
return o.writer.Close()
return o.closer.Close()
}
func (o *ServerConnection) Reset(auth internet.Authenticator, input func([]byte)) {
o.auth = auth
o.input = input
}
func (o *ServerConnection) Input(b *alloc.Buffer) {
defer b.Release()
if o.auth.Open(b) {
o.input(b.Bytes())
}
func (o *ServerConnection) Reset(input func([]Segment)) {
}
func (o *ServerConnection) LocalAddr() net.Addr {
@ -85,12 +81,14 @@ func (o *ServerConnection) Id() internal.ConnectionId {
type Listener struct {
sync.Mutex
running bool
authenticator internet.Authenticator
sessions map[ConnectionId]*Connection
awaitingConns chan *Connection
hub *udp.UDPHub
tlsConfig *tls.Config
config *Config
reader PacketReader
header internet.PacketHeader
security cipher.AEAD
}
func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) {
@ -102,12 +100,21 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
kcpSettings := networkSettings.(*Config)
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
auth, err := kcpSettings.GetAuthenticator()
header, err := kcpSettings.GetPackerHeader()
if err != nil {
return nil, err
return nil, errors.Base(err).Message("KCP|Listener: Failed to create packet header.")
}
security, err := kcpSettings.GetSecurity()
if err != nil {
return nil, errors.Base(err).Message("KCP|Listener: Failed to create security.")
}
l := &Listener{
authenticator: auth,
header: header,
security: security,
reader: &KCPPacketReader{
Header: header,
Security: security,
},
sessions: make(map[ConnectionId]*Connection),
awaitingConns: make(chan *Connection, 64),
running: true,
@ -138,10 +145,12 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
src := session.Source
if valid := v.authenticator.Open(payload); !valid {
segments := v.reader.Read(payload.Bytes())
if len(segments) == 0 {
log.Info("KCP|Listener: discarding invalid payload from ", src)
return
}
if !v.running {
return
}
@ -153,8 +162,9 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
if payload.Len() < 4 {
return
}
conv := serial.BytesToUint16(payload.BytesTo(2))
cmd := Command(payload.Byte(2))
conv := segments[0].Conversation()
cmd := segments[0].Command()
id := ConnectionId{
Remote: src.Address,
Port: src.Port,
@ -177,17 +187,18 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
Port: int(src.Port),
}
localAddr := v.hub.Addr()
auth, err := v.config.GetAuthenticator()
if err != nil {
log.Error("KCP|Listener: Failed to create authenticator: ", err)
}
sConn := &ServerConnection{
id: internal.NewConnectionId(v2net.LocalHostIP, src),
local: localAddr,
remote: remoteAddr,
writer: writer,
writer: &KCPPacketWriter{
Header: v.header,
Writer: writer,
Security: v.security,
},
closer: writer,
}
conn = NewConnection(conv, sConn, v, auth, v.config)
conn = NewConnection(conv, sConn, v, v.config)
select {
case v.awaitingConns <- conn:
case <-time.After(time.Second * 5):
@ -196,7 +207,7 @@ func (v *Listener) OnReceive(payload *alloc.Buffer, session *proxy.SessionInfo)
}
v.sessions[id] = conn
}
conn.Input(payload.Bytes())
conn.Input(segments)
}
func (v *Listener) Remove(id ConnectionId) {

View File

@ -5,8 +5,6 @@ import (
"sync"
"v2ray.com/core/common/alloc"
v2io "v2ray.com/core/common/io"
"v2ray.com/core/transport/internet"
)
type SegmentWriter interface {
@ -17,13 +15,14 @@ type BufferedSegmentWriter struct {
sync.Mutex
mtu uint32
buffer *alloc.Buffer
writer v2io.Writer
writer io.Writer
}
func NewSegmentWriter(writer *AuthenticationWriter) *BufferedSegmentWriter {
func NewSegmentWriter(writer io.Writer, mtu uint32) *BufferedSegmentWriter {
return &BufferedSegmentWriter{
mtu: writer.Mtu(),
mtu: mtu,
writer: writer,
buffer: alloc.NewSmallBuffer(),
}
}
@ -36,45 +35,21 @@ func (v *BufferedSegmentWriter) Write(seg Segment) {
v.FlushWithoutLock()
}
if v.buffer == nil {
v.buffer = alloc.NewSmallBuffer()
}
v.buffer.AppendFunc(seg.Bytes())
}
func (v *BufferedSegmentWriter) FlushWithoutLock() {
v.writer.Write(v.buffer)
v.buffer = nil
v.writer.Write(v.buffer.Bytes())
v.buffer.Clear()
}
func (v *BufferedSegmentWriter) Flush() {
v.Lock()
defer v.Unlock()
if v.buffer.Len() == 0 {
if v.buffer.IsEmpty() {
return
}
v.FlushWithoutLock()
}
type AuthenticationWriter struct {
Authenticator internet.Authenticator
Writer io.Writer
Config *Config
}
func (v *AuthenticationWriter) Write(payload *alloc.Buffer) error {
defer payload.Release()
v.Authenticator.Seal(payload)
_, err := v.Writer.Write(payload.Bytes())
return err
}
func (v *AuthenticationWriter) Release() {}
func (v *AuthenticationWriter) Mtu() uint32 {
return v.Config.Mtu.GetValue() - uint32(v.Authenticator.Overhead())
}

View File

@ -24,6 +24,7 @@ const (
type Segment interface {
common.Releasable
Conversation() uint16
Command() Command
ByteSize() int
Bytes() alloc.BytesWriter
}
@ -52,6 +53,10 @@ func (v *DataSegment) Conversation() uint16 {
return v.Conv
}
func (v *DataSegment) Command() Command {
return CommandData
}
func (v *DataSegment) SetData(b []byte) {
if v.Data == nil {
v.Data = alloc.NewSmallBuffer()
@ -104,6 +109,10 @@ func (v *AckSegment) Conversation() uint16 {
return v.Conv
}
func (v *AckSegment) Command() Command {
return CommandACK
}
func (v *AckSegment) PutTimestamp(timestamp uint32) {
if timestamp-v.Timestamp < 0x7FFFFFFF {
v.Timestamp = timestamp
@ -144,7 +153,7 @@ func (v *AckSegment) Release() {
type CmdOnlySegment struct {
Conv uint16
Command Command
Cmd Command
Option SegmentOption
SendingNext uint32
ReceivinNext uint32
@ -159,6 +168,10 @@ func (v *CmdOnlySegment) Conversation() uint16 {
return v.Conv
}
func (v *CmdOnlySegment) Command() Command {
return v.Cmd
}
func (v *CmdOnlySegment) ByteSize() int {
return 2 + 1 + 1 + 4 + 4 + 4
}
@ -166,7 +179,7 @@ func (v *CmdOnlySegment) ByteSize() int {
func (v *CmdOnlySegment) Bytes() alloc.BytesWriter {
return func(b []byte) int {
b = serial.Uint16ToBytes(v.Conv, b[:0])
b = append(b, byte(v.Command), byte(v.Option))
b = append(b, byte(v.Cmd), byte(v.Option))
b = serial.Uint32ToBytes(v.SendingNext, b)
b = serial.Uint32ToBytes(v.ReceivinNext, b)
b = serial.Uint32ToBytes(v.PeerRTO, b)
@ -250,7 +263,7 @@ func ReadSegment(buf []byte) (Segment, []byte) {
seg := NewCmdOnlySegment()
seg.Conv = conv
seg.Command = cmd
seg.Cmd = cmd
seg.Option = opt
if len(buf) < 12 {

View File

@ -79,7 +79,7 @@ func TestCmdSegment(t *testing.T) {
seg := &CmdOnlySegment{
Conv: 1,
Command: CommandPing,
Cmd: CommandPing,
Option: SegmentOptionClose,
SendingNext: 11,
ReceivinNext: 13,
@ -95,7 +95,7 @@ func TestCmdSegment(t *testing.T) {
iseg, _ := ReadSegment(bytes)
seg2 := iseg.(*CmdOnlySegment)
assert.Uint16(seg2.Conv).Equals(seg.Conv)
assert.Byte(byte(seg2.Command)).Equals(byte(seg.Command))
assert.Byte(byte(seg2.Command())).Equals(byte(seg.Command()))
assert.Byte(byte(seg2.Option)).Equals(byte(seg.Option))
assert.Uint32(seg2.SendingNext).Equals(seg.SendingNext)
assert.Uint32(seg2.ReceivinNext).Equals(seg.ReceivinNext)