1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-21 17:46:58 -05:00

per inbound session history

This commit is contained in:
Darien Raymond 2017-02-12 16:53:23 +01:00
parent 10d26f2d7f
commit ec95caa946
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
3 changed files with 28 additions and 21 deletions

View File

@ -45,10 +45,12 @@ func TestRequestSerialization(t *testing.T) {
buffer2.Append(buffer.Bytes())
ctx, cancel := context.WithCancel(context.Background())
sessionHistory := NewSessionHistory(ctx)
userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash)
userValidator.Add(user)
server := NewServerSession(userValidator)
server := NewServerSession(userValidator, sessionHistory)
actualRequest, err := server.DecodeRequestHeader(buffer)
assert.Error(err).IsNil()

View File

@ -1,6 +1,7 @@
package encoding
import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
@ -25,26 +26,26 @@ type sessionId struct {
nonce [16]byte
}
type sessionHistory struct {
type SessionHistory struct {
sync.RWMutex
cache map[sessionId]time.Time
}
func newSessionHistory() *sessionHistory {
h := &sessionHistory{
func NewSessionHistory(ctx context.Context) *SessionHistory {
h := &SessionHistory{
cache: make(map[sessionId]time.Time, 128),
}
go h.run()
go h.run(ctx)
return h
}
func (h *sessionHistory) Add(session sessionId) {
func (h *SessionHistory) add(session sessionId) {
h.Lock()
h.cache[session] = time.Now().Add(time.Minute * 3)
h.Unlock()
}
func (h *sessionHistory) Has(session sessionId) bool {
func (h *SessionHistory) has(session sessionId) bool {
h.RLock()
defer h.RUnlock()
@ -54,9 +55,13 @@ func (h *sessionHistory) Has(session sessionId) bool {
return false
}
func (h *sessionHistory) run() {
func (h *SessionHistory) run(ctx context.Context) {
for {
time.Sleep(time.Second * 30)
select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
}
session2Remove := make([]sessionId, 0, 16)
now := time.Now()
h.Lock()
@ -72,12 +77,9 @@ func (h *sessionHistory) run() {
}
}
var (
globalSessionHistory = newSessionHistory()
)
type ServerSession struct {
userValidator protocol.UserValidator
sessionHistory *SessionHistory
requestBodyKey []byte
requestBodyIV []byte
responseBodyKey []byte
@ -88,9 +90,10 @@ type ServerSession struct {
// NewServerSession creates a new ServerSession, using the given UserValidator.
// The ServerSession instance doesn't take ownership of the validator.
func NewServerSession(validator protocol.UserValidator) *ServerSession {
func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionHistory) *ServerSession {
return &ServerSession{
userValidator: validator,
sessionHistory: sessionHistory,
}
}
@ -140,10 +143,10 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
copy(sid.user[:], vmessAccount.ID.Bytes())
copy(sid.key[:], v.requestBodyKey)
copy(sid.nonce[:], v.requestBodyIV)
if globalSessionHistory.Has(sid) {
if v.sessionHistory.has(sid) {
return nil, errors.New("VMess|Server: Duplicated session id. Possibly under reply attack.")
}
globalSessionHistory.Add(sid)
v.sessionHistory.add(sid)
v.responseHeader = buffer[33] // 1 byte
request.Option = protocol.RequestOption(buffer[34]) // 1 byte

View File

@ -78,6 +78,7 @@ type VMessInboundHandler struct {
clients protocol.UserValidator
usersByEmail *userByEmail
detours *DetourConfig
sessionHistory *encoding.SessionHistory
}
func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
@ -95,6 +96,7 @@ func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
clients: allowedClients,
detours: config.Detour,
usersByEmail: NewUserByEmail(config.User, config.GetDefaultValue()),
sessionHistory: encoding.NewSessionHistory(ctx),
}
space.OnInitialize(func() error {
@ -171,7 +173,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
connection.SetReadDeadline(time.Now().Add(time.Second * 8))
reader := bufio.NewReader(connection)
session := encoding.NewServerSession(v.clients)
session := encoding.NewServerSession(v.clients, v.sessionHistory)
request, err := session.DecodeRequestHeader(reader)
if err != nil {