From fecf9cc5b89a23c9909becee781220a9a23db4a9 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Sat, 13 Aug 2016 21:44:36 +0800 Subject: [PATCH] Added WebSocket transport --- transport/internet/ws/config.go | 21 +++ transport/internet/ws/config_json.go | 25 +++ transport/internet/ws/connection.go | 110 ++++++++++++ transport/internet/ws/connection_cache.go | 112 ++++++++++++ transport/internet/ws/connection_test.go | 19 +++ transport/internet/ws/dialer.go | 119 +++++++++++++ transport/internet/ws/hub.go | 162 ++++++++++++++++++ transport/internet/ws/wsconn.go | 199 ++++++++++++++++++++++ 8 files changed, 767 insertions(+) create mode 100644 transport/internet/ws/config.go create mode 100644 transport/internet/ws/config_json.go create mode 100644 transport/internet/ws/connection.go create mode 100644 transport/internet/ws/connection_cache.go create mode 100644 transport/internet/ws/connection_test.go create mode 100644 transport/internet/ws/dialer.go create mode 100644 transport/internet/ws/hub.go create mode 100644 transport/internet/ws/wsconn.go diff --git a/transport/internet/ws/config.go b/transport/internet/ws/config.go new file mode 100644 index 000000000..e078d607a --- /dev/null +++ b/transport/internet/ws/config.go @@ -0,0 +1,21 @@ +package ws + +type Config struct { + ConnectionReuse bool + Path string + Pto string + Cert string + PrivKey string +} + +func (this *Config) Apply() { + effectiveConfig = this +} + +var ( + effectiveConfig = &Config{ + ConnectionReuse: true, + Path: "", + Pto: "", + } +) diff --git a/transport/internet/ws/config_json.go b/transport/internet/ws/config_json.go new file mode 100644 index 000000000..843e549b1 --- /dev/null +++ b/transport/internet/ws/config_json.go @@ -0,0 +1,25 @@ +package ws + +import ( + "encoding/json" +) + +func (this *Config) UnmarshalJSON(data []byte) error { + type JsonConfig struct { + ConnectionReuse bool `json:"connectionReuse"` + Path string `json:"Path"` + Pto string `json:"Pto"` + } + jsonConfig := &JsonConfig{ + ConnectionReuse: true, + Path: "", + Pto: "", + } + if err := json.Unmarshal(data, jsonConfig); err != nil { + return err + } + this.ConnectionReuse = jsonConfig.ConnectionReuse + this.Path = jsonConfig.Path + this.Pto = jsonConfig.Pto + return nil +} diff --git a/transport/internet/ws/connection.go b/transport/internet/ws/connection.go new file mode 100644 index 000000000..141fdb32d --- /dev/null +++ b/transport/internet/ws/connection.go @@ -0,0 +1,110 @@ +package ws + +import ( + "errors" + "io" + "net" + "reflect" + "time" +) + +var ( + ErrInvalidConn = errors.New("Invalid Connection.") +) + +type ConnectionManager interface { + Recycle(string, *wsconn) +} + +type Connection struct { + dest string + conn *wsconn + listener ConnectionManager + reusable bool +} + +func NewConnection(dest string, conn *wsconn, manager ConnectionManager) *Connection { + return &Connection{ + dest: dest, + conn: conn, + listener: manager, + reusable: effectiveConfig.ConnectionReuse, + } +} + +func (this *Connection) Read(b []byte) (int, error) { + if this == nil || this.conn == nil { + return 0, io.EOF + } + + return this.conn.Read(b) +} + +func (this *Connection) Write(b []byte) (int, error) { + if this == nil || this.conn == nil { + return 0, io.ErrClosedPipe + } + return this.conn.Write(b) +} + +func (this *Connection) Close() error { + if this == nil || this.conn == nil { + return io.ErrClosedPipe + } + if this.Reusable() { + this.listener.Recycle(this.dest, this.conn) + return nil + } + err := this.conn.Close() + this.conn = nil + return err +} + +func (this *Connection) LocalAddr() net.Addr { + return this.conn.LocalAddr() +} + +func (this *Connection) RemoteAddr() net.Addr { + return this.conn.RemoteAddr() +} + +func (this *Connection) SetDeadline(t time.Time) error { + return this.conn.SetDeadline(t) +} + +func (this *Connection) SetReadDeadline(t time.Time) error { + return this.conn.SetReadDeadline(t) +} + +func (this *Connection) SetWriteDeadline(t time.Time) error { + return this.conn.SetWriteDeadline(t) +} + +func (this *Connection) SetReusable(reusable bool) { + if !effectiveConfig.ConnectionReuse { + return + } + this.reusable = reusable +} + +func (this *Connection) Reusable() bool { + return this.reusable +} + +func (this *Connection) SysFd() (int, error) { + return getSysFd(this.conn) +} + +func getSysFd(conn net.Conn) (int, error) { + cv := reflect.ValueOf(conn) + switch ce := cv.Elem(); ce.Kind() { + case reflect.Struct: + netfd := ce.FieldByName("conn").FieldByName("fd") + switch fe := netfd.Elem(); fe.Kind() { + case reflect.Struct: + fd := fe.FieldByName("sysfd") + return int(fd.Int()), nil + } + } + return 0, ErrInvalidConn +} diff --git a/transport/internet/ws/connection_cache.go b/transport/internet/ws/connection_cache.go new file mode 100644 index 000000000..e0743d756 --- /dev/null +++ b/transport/internet/ws/connection_cache.go @@ -0,0 +1,112 @@ +package ws + +import ( + "net" + "sync" + "time" + + "github.com/v2ray/v2ray-core/common/signal" +) + +type AwaitingConnection struct { + conn *wsconn + expire time.Time +} + +func (this *AwaitingConnection) Expired() bool { + return this.expire.Before(time.Now()) +} + +type ConnectionCache struct { + sync.Mutex + cache map[string][]*AwaitingConnection + cleanupOnce signal.Once +} + +func NewConnectionCache() *ConnectionCache { + return &ConnectionCache{ + cache: make(map[string][]*AwaitingConnection), + } +} + +func (this *ConnectionCache) Cleanup() { + defer this.cleanupOnce.Reset() + + for len(this.cache) > 0 { + time.Sleep(time.Second * 4) + this.Lock() + for key, value := range this.cache { + size := len(value) + changed := false + for i := 0; i < size; { + if value[i].Expired() { + value[i].conn.Close() + value[i] = value[size-1] + size-- + changed = true + } else { + i++ + } + } + if changed { + for i := size; i < len(value); i++ { + value[i] = nil + } + value = value[:size] + this.cache[key] = value + } + } + this.Unlock() + } +} + +func (this *ConnectionCache) Recycle(dest string, conn *wsconn) { + this.Lock() + defer this.Unlock() + + aconn := &AwaitingConnection{ + conn: conn, + expire: time.Now().Add(time.Second * 4), + } + + var list []*AwaitingConnection + if v, found := this.cache[dest]; found { + v = append(v, aconn) + list = v + } else { + list = []*AwaitingConnection{aconn} + } + this.cache[dest] = list + + go this.cleanupOnce.Do(this.Cleanup) +} + +func FindFirstValid(list []*AwaitingConnection) int { + for idx, conn := range list { + if !conn.Expired() && !conn.conn.connClosing { + return idx + } + go conn.conn.Close() + } + return -1 +} + +func (this *ConnectionCache) Get(dest string) net.Conn { + this.Lock() + defer this.Unlock() + + list, found := this.cache[dest] + if !found { + return nil + } + + firstValid := FindFirstValid(list) + if firstValid == -1 { + delete(this.cache, dest) + return nil + } + res := list[firstValid].conn + list = list[firstValid+1:] + this.cache[dest] = list + return res +} diff --git a/transport/internet/ws/connection_test.go b/transport/internet/ws/connection_test.go new file mode 100644 index 000000000..5e594f4e9 --- /dev/null +++ b/transport/internet/ws/connection_test.go @@ -0,0 +1,19 @@ +package ws_test + +import ( + "net" + "testing" + + "github.com/v2ray/v2ray-core/testing/assert" + . "github.com/v2ray/v2ray-core/transport/internet/tcp" +) + +func TestRawConnection(t *testing.T) { + assert := assert.On(t) + + rawConn := RawConnection{net.TCPConn{}} + assert.Bool(rawConn.Reusable()).IsFalse() + + rawConn.SetReusable(true) + assert.Bool(rawConn.Reusable()).IsFalse() +} diff --git a/transport/internet/ws/dialer.go b/transport/internet/ws/dialer.go new file mode 100644 index 000000000..bb1e88b0e --- /dev/null +++ b/transport/internet/ws/dialer.go @@ -0,0 +1,119 @@ +package ws + +import ( + "fmt" + "net" + + "github.com/gorilla/websocket" + "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/transport/internet" +) + +var ( + globalCache = NewConnectionCache() +) + +func Dial(src v2net.Address, dest v2net.Destination) (internet.Connection, error) { + log.Info("Dailing WS to ", dest) + if src == nil { + src = v2net.AnyIP + } + id := src.String() + "-" + dest.NetAddr() + var conn *wsconn + if dest.IsTCP() && effectiveConfig.ConnectionReuse { + connt := globalCache.Get(id) + if connt != nil { + conn = connt.(*wsconn) + } + } + if conn == nil { + var err error + conn, err = wsDial(src, dest) + if err != nil { + log.Warning("WS Dial failed:" + err.Error()) + return nil, err + } + } + return NewConnection(id, conn, globalCache), nil +} + +func init() { + internet.WSDialer = Dial +} + +func wsDial(src v2net.Address, dest v2net.Destination) (*wsconn, error) { + //internet.DialToDest(src, dest) + commonDial := func(network, addr string) (net.Conn, error) { + return internet.DialToDest(src, dest) + } + + dialer := websocket.Dialer{NetDial: commonDial, ReadBufferSize: 65536, WriteBufferSize: 65536} + + effpto := func(dst v2net.Destination) string { + + if effectiveConfig.Pto != "" { + return effectiveConfig.Pto + } + + switch dst.Port().Value() { + /* + Since the value is not given explicitly, + We are guessing it now. + + HTTP Port: + 80 + 8080 + 8880 + 2052 + 2082 + 2086 + 2095 + + HTTPS Port: + 443 + 2053 + 2083 + 2087 + 2096 + 8443 + + if the port you are using is not well-known, + specify it to avoid this process. + + We will re return "CRASH"turn "unknown" if we can't guess it, cause Dial to fail. + */ + case 80: + case 8080: + case 8880: + case 2052: + case 2082: + case 2086: + case 2095: + return "ws" + case 443: + case 2053: + case 2083: + case 2087: + case 2096: + case 8443: + return "wss" + default: + return "unknown" + } + panic("Runtime unstable. Please report this bug to developers.") + }(dest) + + uri := func(dst v2net.Destination, pto string, path string) string { + return fmt.Sprintf("%v://%v:%v/%v", pto, dst.NetAddr(), dst.Port(), path) + }(dest, effpto, effectiveConfig.Path) + conn, _, err := dialer.Dial(uri, nil) + if err != nil { + return nil, err + } + return func() internet.Connection { + connv2ray := &wsconn{wsc: conn, connClosing: false} + connv2ray.setup() + return connv2ray + }().(*wsconn), nil +} diff --git a/transport/internet/ws/hub.go b/transport/internet/ws/hub.go new file mode 100644 index 000000000..cc4563c20 --- /dev/null +++ b/transport/internet/ws/hub.go @@ -0,0 +1,162 @@ +package ws + +import ( + "errors" + "net" + "net/http" + "strconv" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/v2ray/v2ray-core/common/log" + v2net "github.com/v2ray/v2ray-core/common/net" + "github.com/v2ray/v2ray-core/transport/internet" +) + +var ( + ErrClosedListener = errors.New("Listener is closed.") +) + +type ConnectionWithError struct { + conn net.Conn + err error +} + +type WSListener struct { + sync.Mutex + acccepting bool + awaitingConns chan *ConnectionWithError +} + +func ListenWS(address v2net.Address, port v2net.Port) (internet.Listener, error) { + + l := &WSListener{ + acccepting: true, + awaitingConns: make(chan *ConnectionWithError, 32), + } + + err := l.listenws(address, port) + + return l, err +} + +func (wsl *WSListener) listenws(address v2net.Address, port v2net.Port) error { + + http.HandleFunc("/"+effectiveConfig.Path, func(w http.ResponseWriter, r *http.Request) { + log.Warning("WS:WSListener->listenws->(HandleFunc,lambda 2)! Accepting websocket") + con, err := wsl.converttovws(w, r) + if err != nil { + log.Warning("WS:WSListener->listenws->(HandleFunc,lambda 2)!" + err.Error()) + return + } + + select { + case wsl.awaitingConns <- &ConnectionWithError{ + conn: con, + err: err, + }: + log.Warning("WS:WSListener->listenws->(HandleFunc,lambda 2)! transferd websocket") + default: + if con != nil { + con.Close() + } + } + //con.retloc.Wait() + return + + }) + + errchan := make(chan error) + + go func() { + err := http.ListenAndServe(address.String()+":"+strconv.Itoa(int(port.Value())), nil) + errchan <- err + return + }() + + var err error + select { + case err = <-errchan: + case <-time.After(time.Second * 2): + //Should this listen fail after 2 sec, it could gone untracked. + } + + if err != nil { + log.Error("WS:WSListener->listenws->ListenAndServe!" + err.Error()) + } + + return err + +} + +func (wsl *WSListener) converttovws(w http.ResponseWriter, r *http.Request) (*wsconn, error) { + var upgrader = websocket.Upgrader{ + ReadBufferSize: 65536, + WriteBufferSize: 65536, + } + conn, err := upgrader.Upgrade(w, r, nil) + + if err != nil { + return nil, err + } + + wrapedConn := &wsconn{wsc: conn, connClosing: false} + wrapedConn.setup() + return wrapedConn, nil +} + +func (this *WSListener) Accept() (internet.Connection, error) { + for this.acccepting { + select { + case connErr, open := <-this.awaitingConns: + log.Info("WSListener: conn accepted") + if !open { + return nil, ErrClosedListener + } + if connErr.err != nil { + return nil, connErr.err + } + return NewConnection("", connErr.conn.(*wsconn), this), nil + case <-time.After(time.Second * 2): + } + } + return nil, ErrClosedListener +} + +func (this *WSListener) Recycle(dest string, conn *wsconn) { + this.Lock() + defer this.Unlock() + if !this.acccepting { + return + } + select { + case this.awaitingConns <- &ConnectionWithError{conn: conn}: + default: + conn.Close() + } +} + +func (this *WSListener) Addr() net.Addr { + return nil +} + +func (this *WSListener) Close() error { + this.Lock() + defer this.Unlock() + this.acccepting = false + + log.Warning("WSListener: Yet to support close listening HTTP service") + + close(this.awaitingConns) + for connErr := range this.awaitingConns { + if connErr.conn != nil { + go connErr.conn.Close() + } + } + return nil +} + +func init() { + internet.WSListenFunc = ListenWS +} diff --git a/transport/internet/ws/wsconn.go b/transport/internet/ws/wsconn.go new file mode 100644 index 000000000..3215b7317 --- /dev/null +++ b/transport/internet/ws/wsconn.go @@ -0,0 +1,199 @@ +package ws + +import ( + "bufio" + "io" + "net" + "sync" + "time" + + "github.com/v2ray/v2ray-core/common/log" + + "github.com/gorilla/websocket" +) + +type wsconn struct { + wsc *websocket.Conn + readBuffer *bufio.Reader + connClosing bool + reusable bool + retloc *sync.Cond + rlock *sync.Mutex + wlock *sync.Mutex +} + +func (ws *wsconn) Read(b []byte) (n int, err error) { + + //defer ws.rlock.Unlock() + //ws.checkifRWAfterClosing() + if ws.connClosing { + + return 0, io.EOF + } + getNewBuffer := func() error { + _, r, err := ws.wsc.NextReader() + if err != nil { + log.Warning("WS transport: ws connection NewFrameReader return " + err.Error()) + ws.connClosing = true + ws.Close() + return err + } + ws.readBuffer = bufio.NewReader(r) + return nil + } + + readNext := func(b []byte) (n int, err error) { + if ws.readBuffer == nil { + err = getNewBuffer() + if err != nil { + //ws.Close() + return 0, err + } + } + + n, err = ws.readBuffer.Read(b) + + if err == nil { + return n, err + } + + if err == io.EOF { + ws.readBuffer = nil + if n == 0 { + return ws.Read(b) + } + return n, nil + } + //ws.Close() + return n, err + + } + n, err = readNext(b) + + return n, err + +} + +func (ws *wsconn) Write(b []byte) (n int, err error) { + + //defer + //ws.checkifRWAfterClosing() + if ws.connClosing { + + return 0, io.EOF + } + writeWs := func(b []byte) (n int, err error) { + wr, err := ws.wsc.NextWriter(websocket.BinaryMessage) + if err != nil { + log.Warning("WS transport: ws connection NewFrameReader return " + err.Error()) + ws.connClosing = true + ws.Close() + return 0, err + } + n, err = wr.Write(b) + if err != nil { + //ws.Close() + return 0, err + } + err = wr.Close() + if err != nil { + //ws.Close() + return 0, err + } + return n, err + } + n, err = writeWs(b) + return n, err +} +func (ws *wsconn) Close() error { + ws.connClosing = true + err := ws.wsc.Close() + ws.retloc.Broadcast() + return err +} +func (ws *wsconn) LocalAddr() net.Addr { + return ws.wsc.LocalAddr() +} +func (ws *wsconn) RemoteAddr() net.Addr { + return ws.wsc.RemoteAddr() +} +func (ws *wsconn) SetDeadline(t time.Time) error { + return func() error { + errr := ws.SetReadDeadline(t) + errw := ws.SetWriteDeadline(t) + if errr == nil || errw == nil { + return nil + } + if errr != nil { + return errr + } + + return errw + }() +} +func (ws *wsconn) SetReadDeadline(t time.Time) error { + return ws.wsc.SetReadDeadline(t) +} +func (ws *wsconn) SetWriteDeadline(t time.Time) error { + return ws.wsc.SetWriteDeadline(t) +} + +func (ws *wsconn) checkifRWAfterClosing() { + if ws.connClosing { + log.Error("WS transport: Read or Write After Conn have been marked closing, this can be dangerous.") + //panic("WS transport: Read or Write After Conn have been marked closing. Please report this crash to developer.") + } +} + +func (ws *wsconn) setup() { + ws.connClosing = false + + ws.rlock = &sync.Mutex{} + ws.wlock = &sync.Mutex{} + + initConnectedCond := func() { + rsl := &sync.Mutex{} + ws.retloc = sync.NewCond(rsl) + } + + initConnectedCond() + //ws.pingPong() +} + +func (ws *wsconn) Reusable() bool { + return ws.reusable && !ws.connClosing +} + +func (ws *wsconn) SetReusable(reusable bool) { + if !effectiveConfig.ConnectionReuse { + return + } + ws.reusable = reusable +} + +func (ws *wsconn) pingPong() { + pongRcv := make(chan int, 0) + ws.wsc.SetPongHandler(func(data string) error { + pongRcv <- 0 + return nil + }) + + go func() { + for !ws.connClosing { + ws.wsc.WriteMessage(websocket.PingMessage, nil) + tick := time.NewTicker(time.Second * 3) + + select { + case <-pongRcv: + break + case <-tick.C: + ws.Close() + } + <-tick.C + tick.Stop() + } + + return + }() + +}