1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-11-17 09:56:18 -05:00

Merge pull request #2 from v2ray/master

Update
This commit is contained in:
sunshineplan 2017-12-02 16:58:06 +08:00 committed by GitHub
commit d1fa98b60b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
225 changed files with 5693 additions and 4239 deletions

46
.github/CODE_OF_CONDUCT.md vendored Normal file
View File

@ -0,0 +1,46 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at love@v2ray.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]
[homepage]: http://contributor-covenant.org
[version]: http://contributor-covenant.org/version/1/4/

View File

@ -1,44 +1,94 @@
提交 Issue 之前请先阅读 [Issue 指引](https://www.v2ray.com/zh_cn/chapter_01/issue.html),然后回答下面的问题,谢谢。 Please skip to the English section below if you don't write Chinese.
Please read the [instruction](https://www.v2ray.com/en/get_started/issue.html) and answer the following questions before submitting your issue. Thank you.
中文:
提交 Issue 之前请先阅读 [Issue 指引](https://www.v2ray.com/chapter_01/issue.html),然后回答下面的问题,谢谢。
除非特殊情况,请完整填写所有问题。不按模板发的 issue 将直接被关闭。
1) 你正在使用哪个版本的 V2Ray如果服务器和客户端使用了不同版本请注明 1) 你正在使用哪个版本的 V2Ray如果服务器和客户端使用了不同版本请注明
What version of V2Ray are you using (If you deploy different version on server and client, please explicitly point out)?
2) 你的使用场景是什么?比如使用 Chrome 通过 Socks/VMess 代理观看 YouTube 视频。 2) 你的使用场景是什么?比如使用 Chrome 通过 Socks/VMess 代理观看 YouTube 视频。
What's your scenario of using V2Ray? E.g., Watching YouTube videos in Chrome via Socks/VMess proxy.
3) 你看到的不正常的现象是什么? 3) 你看到的不正常的现象是什么请描述具体现象比如访问超时TLS 证书错误等)
What did you see?
4) 你期待看到的正确表现是怎样的? 4) 你期待看到的正确表现是怎样的?
What's your expectation?
5) 请附上你的配置文件(提交 Issue 前请隐藏服务器端IP地址 5) 请附上你的配置(提交 Issue 前请隐藏服务器端IP地址
Please attach your configuration file (**Mask IP addresses before submit this issue**).
Server Configuration File服务器端配置文件): 服务器端配置:
```javascript ```javascript
// 在这里附上服务器端配置文件 // 在这里附上服务器端配置文件
// Please attach your server configuration file here.
``` ```
Client Configuration File客户端配置文件): 客户端配置:
```javascript ```javascript
// 在这里附上客户端配置文件 // 在这里附上客户端配置
// Please attach your client configuration file here.
``` ```
6) 请附上出错时软件输出的日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。 6) 请附上出错时软件输出的错误日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
Please attach the log file, especially the bottom lines if the file is large. Log file is usually `/var/log/v2ray/error.log` on Linux.
Server Log File服务器端日志: 服务器端错误日志:
``` ```
// 在这里附上服务器端日志 // 在这里附上服务器端日志
// Please attach your server log here.
``` ```
Client Log File客户端日志: 客户端错误日志:
``` ```
// 在这里附上客户端日志 // 在这里附上客户端日志
// Please attach your client log here.
``` ```
7) 请附上访问日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
```
// 在这里附上服务器端日志
```
8) 其它相关的配置文件(如 Nginx和相关日志。
请预览一下你填的内容再提交。
如果你已经填完上面的问卷,请把下面的英文部份删除,再提交 Issue。
Please remove the Chinese section above.
English:
Please read the [instruction](https://www.v2ray.com/en/get_started/issue.html) and answer the following questions before submitting your issue. Thank you.
Please answer all the questions with enough information. All issues not following this template will be closed immediately.
1) What version of V2Ray are you using (If you deploy different version on server and client, please explicitly point out)?
2) What's your scenario of using V2Ray? E.g., Watching YouTube videos in Chrome via Socks/VMess proxy.
3) What did you see? (Please describe in detail, such as timeout, fake TLS certificate etc)
4) What's your expectation?
5) Please attach your configuration file (**Mask IP addresses before submit this issue**).
Server configuration:
```javascript
// Please attach your server configuration here.
```
Client configuration:
```javascript
// Please attach your client configuration here.
```
6) Please attach error logs, especially the bottom lines if the file is large. Error log file is usually at `/var/log/v2ray/error.log` on Linux.
Server error log:
```
// Please attach your server error log here.
```
Client error log:
```
// Please attach your client error log here.
```
7) Please attach access log. Access log is usually at '/var/log/v2ray/access.log' on Linux.
```
// Please attach your server access log here.
```
Please review your issue before submitting.
8) Other configurations (such as Nginx) and logs.

63
.github/SUPPORT.md vendored Normal file
View File

@ -0,0 +1,63 @@
# V2Ray 用户支持 (User Support)
**English reader please skip to the [English section](#way-to-get-support) below**
## 获得帮助信息的途径
您可以从以下渠道获取帮助:
1. 官方网站:[v2ray.com](https://www.v2ray.com)
1. Github[Issues](https://github.com/v2ray/v2ray-core/issues)
1. Telegram[主群](https://t.me/projectv2ray)
## Github Issue 规则
1. 请按模板填写 issue
1. 配置文件内容使用格式化代码段进行修饰(见下面的解释);
1. 在提交 issue 前尝试减化配置文件,比如删除不必要 inbound / outbound 模块;
1. 在提交 issue 前尝试确定问题所在,比如将 socks 代理换成 http 再次观察问题是否能重现;
1. 配置文件必须结构完整,即除了必要的隐私信息之外,配置文件可以直接拿来运行。
**不按模板填写的 issue 将直接被关闭**
## 格式化代码段
在配置文件上下加入 Markdown 特定的修饰符,如下:
\`\`\`javascript
{
// 配置文件内容
}
\`\`\`
## Way to Get Support
You may get help in the following ways:
1. Office Site: [v2ray.com](https://www.v2ray.com)
1. Github: [Issues](https://github.com/v2ray/v2ray-core/issues)
1. Telegram: [Main Group](https://t.me/projectv2ray)
## Github Issue Rules
1. Please fill in the issue template.
1. Decorate config file with Markdown formatter (See below).
1. Try to simplify config file before submitting the issue, such as removing unnecessary inbound / outbound blocks.
1. Try to determine the cause of the issue, for example, replacing socks inbound with http inbound to see if the issue still exists.
1. Config file must be structurally complete.
**Any issue not following the issue template will be closed immediately.**
## Code formatter
Add the following Markdown decorator to config file content:
\`\`\`javascript
{
// config file
}
\`\`\`

6
.gitmodules vendored
View File

@ -1,3 +1,9 @@
[submodule "vendor/h12.me/socks"] [submodule "vendor/h12.me/socks"]
path = vendor/h12.me/socks path = vendor/h12.me/socks
url = https://github.com/h12w/socks url = https://github.com/h12w/socks
[submodule "vendor/github.com/shadowsocks/go-shadowsocks2"]
path = vendor/github.com/shadowsocks/go-shadowsocks2
url = https://github.com/shadowsocks/go-shadowsocks2
[submodule "vendor/github.com/Yawning/chacha20"]
path = vendor/github.com/Yawning/chacha20
url = https://github.com/Yawning/chacha20

View File

@ -1,7 +1,7 @@
sudo: required sudo: required
language: go language: go
go: go:
- 1.9 - 1.9.2
go_import_path: v2ray.com/core go_import_path: v2ray.com/core
git: git:
depth: 5 depth: 5

19
.vscode/tasks.json vendored
View File

@ -1,13 +1,18 @@
{ {
"version": "0.1.0", "version": "2.0.0",
"command": "go", "command": "go",
"isShellCommand": true, "type": "shell",
"showOutput": "always", "presentation": {
"echo": true,
"reveal": "always",
"focus": false,
"panel": "shared"
},
"tasks": [ "tasks": [
{ {
"taskName": "build", "label": "build",
"args": ["v2ray.com/core/..."], "args": ["v2ray.com/core/..."],
"isBuildCommand": true, "group": "build",
"problemMatcher": { "problemMatcher": {
"owner": "go", "owner": "go",
"fileLocation": ["relative", "${workspaceRoot}"], "fileLocation": ["relative", "${workspaceRoot}"],
@ -20,9 +25,9 @@
} }
}, },
{ {
"taskName": "test", "label": "test",
"args": ["-p", "1", "v2ray.com/core/..."], "args": ["-p", "1", "v2ray.com/core/..."],
"isBuildCommand": false "group": "test"
} }
] ]
} }

View File

@ -1,4 +1,4 @@
# Project V2Ray # Project V
[![Build Status][1]][2] [![codecov.io][3]][4] [![Go Report][5]][6] [![GoDoc][7]][8] [![codebeat][9]][10] [![Build Status][1]][2] [![codecov.io][3]][4] [![Go Report][5]][6] [![GoDoc][7]][8] [![codebeat][9]][10]
@ -13,11 +13,21 @@
[9]: https://codebeat.co/badges/f2354ca8-3e24-463d-a2e3-159af73b2477 "Codebeat badge" [9]: https://codebeat.co/badges/f2354ca8-3e24-463d-a2e3-159af73b2477 "Codebeat badge"
[10]: https://codebeat.co/projects/github-com-v2ray-v2ray-core-master "Codebeat" [10]: https://codebeat.co/projects/github-com-v2ray-v2ray-core-master "Codebeat"
V2Ray 是一个模块化的代理软件包,它的目标是提供常用的代理软件模块,简化网络代理软件的开发。 V 是一个模块化的代理软件包,它的目标是提供常用的代理软件模块,简化网络代理软件的开发。
[官方网站](https://www.v2ray.com/) [官方网站](https://www.v2ray.com/)
V2Ray provides building blocks for network proxy development. Read our [Wiki](https://www.v2ray.com/en/index.html) for more information. V provides building blocks for network proxy development. Read our [Wiki](https://www.v2ray.com/en/index.html) for more information.
## License ## License
[The MIT License (MIT)](https://raw.githubusercontent.com/v2ray/v2ray-core/master/LICENSE) [The MIT License (MIT)](https://raw.githubusercontent.com/v2ray/v2ray-core/master/LICENSE)
## Credits
This repo relies on the following third-party projects:
* In production:
* [miekg/dns](https://github.com/miekg/dns)
* [gorilla/websocket](https://github.com/gorilla/websocket)
* For testing only:
* [h12w/socks](https://github.com/h12w/socks)

View File

@ -39,7 +39,7 @@ func NewDefaultDispatcher(ctx context.Context, config *dispatcher.Config) (*Defa
return nil, newError("no space in context") return nil, newError("no space in context")
} }
d := &DefaultDispatcher{} d := &DefaultDispatcher{}
space.OnInitialize(func() error { space.On(app.SpaceInitializing, func(interface{}) error {
d.ohm = proxyman.OutboundHandlerManagerFromSpace(space) d.ohm = proxyman.OutboundHandlerManagerFromSpace(space)
if d.ohm == nil { if d.ohm == nil {
return newError("OutboundHandlerManager is not found in the space") return newError("OutboundHandlerManager is not found in the space")

View File

@ -3,12 +3,14 @@ package impl_test
import ( import (
"testing" "testing"
"v2ray.com/core/app/proxyman"
. "v2ray.com/core/app/dispatcher/impl" . "v2ray.com/core/app/dispatcher/impl"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestHTTPHeaders(t *testing.T) { func TestHTTPHeaders(t *testing.T) {
assert := assert.On(t) assert := With(t)
cases := []struct { cases := []struct {
input string input string
@ -94,13 +96,13 @@ first_name=John&last_name=Doe&action=Submit`,
for _, test := range cases { for _, test := range cases {
domain, err := SniffHTTP([]byte(test.input)) domain, err := SniffHTTP([]byte(test.input))
assert.String(domain).Equals(test.domain) assert(domain, Equals, test.domain)
assert.Error(err).Equals(test.err) assert(err, Equals, test.err)
} }
} }
func TestTLSHeaders(t *testing.T) { func TestTLSHeaders(t *testing.T) {
assert := assert.On(t) assert := With(t)
cases := []struct { cases := []struct {
input []byte input []byte
@ -180,7 +182,13 @@ func TestTLSHeaders(t *testing.T) {
for _, test := range cases { for _, test := range cases {
domain, err := SniffTLS(test.input) domain, err := SniffTLS(test.input)
assert.String(domain).Equals(test.domain) assert(domain, Equals, test.domain)
assert.Error(err).Equals(test.err) assert(err, Equals, test.err)
} }
} }
func TestUnknownSniffer(t *testing.T) {
assert := With(t)
assert(func() { NewSniffer([]proxyman.KnownProtocols{proxyman.KnownProtocols(-1)}) }, Panics)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"v2ray.com/core/app/dispatcher" "v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/log" "v2ray.com/core/app/log"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/dice" "v2ray.com/core/common/dice"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
@ -15,13 +16,16 @@ import (
) )
const ( const (
DefaultTTL = uint32(3600)
CleanupInterval = time.Second * 120 CleanupInterval = time.Second * 120
CleanupThreshold = 512 CleanupThreshold = 512
) )
var ( var (
pseudoDestination = net.UDPDestination(net.LocalHostIP, net.Port(53)) multiQuestionDNS = map[net.Address]bool{
net.IPAddress([]byte{8, 8, 8, 8}): true,
net.IPAddress([]byte{8, 8, 4, 4}): true,
net.IPAddress([]byte{9, 9, 9, 9}): true,
}
) )
type ARecord struct { type ARecord struct {
@ -55,54 +59,52 @@ func NewUDPNameServer(address net.Destination, dispatcher dispatcher.Interface)
return s return s
} }
// Private: Visible for testing. func (s *UDPNameServer) Cleanup() {
func (v *UDPNameServer) Cleanup() {
expiredRequests := make([]uint16, 0, 16) expiredRequests := make([]uint16, 0, 16)
now := time.Now() now := time.Now()
v.Lock() s.Lock()
for id, r := range v.requests { for id, r := range s.requests {
if r.expire.Before(now) { if r.expire.Before(now) {
expiredRequests = append(expiredRequests, id) expiredRequests = append(expiredRequests, id)
close(r.response) close(r.response)
} }
} }
for _, id := range expiredRequests { for _, id := range expiredRequests {
delete(v.requests, id) delete(s.requests, id)
} }
v.Unlock() s.Unlock()
expiredRequests = nil
} }
// Private: Visible for testing. func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
var id uint16 var id uint16
v.Lock() s.Lock()
if len(v.requests) > CleanupThreshold && v.nextCleanup.Before(time.Now()) { if len(s.requests) > CleanupThreshold && s.nextCleanup.Before(time.Now()) {
v.nextCleanup = time.Now().Add(CleanupInterval) s.nextCleanup = time.Now().Add(CleanupInterval)
go v.Cleanup() go s.Cleanup()
} }
for { for {
id = dice.RollUint16() id = dice.RollUint16()
if _, found := v.requests[id]; found { if _, found := s.requests[id]; found {
continue continue
} }
log.Trace(newError("add pending request id ", id).AtDebug()) log.Trace(newError("add pending request id ", id).AtDebug())
v.requests[id] = &PendingRequest{ s.requests[id] = &PendingRequest{
expire: time.Now().Add(time.Second * 8), expire: time.Now().Add(time.Second * 8),
response: response, response: response,
} }
break break
} }
v.Unlock() s.Unlock()
return id return id
} }
// Private: Visible for testing. func (s *UDPNameServer) HandleResponse(payload *buf.Buffer) {
func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
msg := new(dns.Msg) msg := new(dns.Msg)
err := msg.Unpack(payload.Bytes()) err := msg.Unpack(payload.Bytes())
if err != nil { if err == dns.ErrTruncated {
log.Trace(newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning())
} else if err != nil {
log.Trace(newError("failed to parse DNS response").Base(err).AtWarning()) log.Trace(newError("failed to parse DNS response").Base(err).AtWarning())
return return
} }
@ -110,17 +112,17 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
IPs: make([]net.IP, 0, 16), IPs: make([]net.IP, 0, 16),
} }
id := msg.Id id := msg.Id
ttl := DefaultTTL ttl := uint32(3600) // an hour
log.Trace(newError("handling response for id ", id, " content: ", msg.String()).AtDebug()) log.Trace(newError("handling response for id ", id, " content: ", msg).AtDebug())
v.Lock() s.Lock()
request, found := v.requests[id] request, found := s.requests[id]
if !found { if !found {
v.Unlock() s.Unlock()
return return
} }
delete(v.requests, id) delete(s.requests, id)
v.Unlock() s.Unlock()
for _, rr := range msg.Answer { for _, rr := range msg.Answer {
switch rr := rr.(type) { switch rr := rr.(type) {
@ -142,8 +144,7 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
close(request.response) close(request.response)
} }
func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer { func (s *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.Id = id msg.Id = id
msg.RecursionDesired = true msg.RecursionDesired = true
@ -153,34 +154,40 @@ func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
Qtype: dns.TypeA, Qtype: dns.TypeA,
Qclass: dns.ClassINET, Qclass: dns.ClassINET,
}} }}
if multiQuestionDNS[s.address.Address] {
msg.Question = append(msg.Question, dns.Question{
Name: dns.Fqdn(domain),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
})
}
buffer := buf.New() buffer := buf.New()
buffer.AppendSupplier(func(b []byte) (int, error) { common.Must(buffer.Reset(func(b []byte) (int, error) {
writtenBuffer, err := msg.PackBuffer(b) writtenBuffer, err := msg.PackBuffer(b)
return len(writtenBuffer), err return len(writtenBuffer), err
}) }))
return buffer return buffer
} }
func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord { func (s *UDPNameServer) QueryA(domain string) <-chan *ARecord {
response := make(chan *ARecord, 1) response := make(chan *ARecord, 1)
id := v.AssignUnusedID(response) id := s.AssignUnusedID(response)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*8) ctx, cancel := context.WithCancel(context.Background())
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse) s.udpServer.Dispatch(ctx, s.address, s.BuildQueryA(domain, id), s.HandleResponse)
go func() { go func() {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
time.Sleep(time.Second) time.Sleep(time.Second)
v.Lock() s.Lock()
_, found := v.requests[id] _, found := s.requests[id]
v.Unlock() s.Unlock()
if found { if !found {
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
} else {
break break
} }
s.udpServer.Dispatch(ctx, s.address, s.BuildQueryA(domain, id), s.HandleResponse)
} }
cancel() cancel()
}() }()
@ -191,7 +198,7 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
type LocalNameServer struct { type LocalNameServer struct {
} }
func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord { func (*LocalNameServer) QueryA(domain string) <-chan *ARecord {
response := make(chan *ARecord, 1) response := make(chan *ARecord, 1)
go func() { go func() {
@ -205,7 +212,7 @@ func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
response <- &ARecord{ response <- &ARecord{
IPs: ips, IPs: ips,
Expire: time.Now().Add(time.Second * time.Duration(DefaultTTL)), Expire: time.Now().Add(time.Hour),
} }
}() }()

View File

@ -21,11 +21,22 @@ const (
) )
type DomainRecord struct { type DomainRecord struct {
A *ARecord IP []net.IP
Expire time.Time
LastAccess time.Time
}
func (r *DomainRecord) Expired() bool {
return r.Expire.Before(time.Now())
}
func (r *DomainRecord) Inactive() bool {
now := time.Now()
return r.Expire.Before(now) || r.LastAccess.Add(time.Minute*5).Before(now)
} }
type CacheServer struct { type CacheServer struct {
sync.RWMutex sync.Mutex
hosts map[string]net.IP hosts map[string]net.IP
records map[string]*DomainRecord records map[string]*DomainRecord
servers []NameServer servers []NameServer
@ -41,7 +52,7 @@ func NewCacheServer(ctx context.Context, config *dns.Config) (*CacheServer, erro
servers: make([]NameServer, len(config.NameServers)), servers: make([]NameServer, len(config.NameServers)),
hosts: config.GetInternalHosts(), hosts: config.GetInternalHosts(),
} }
space.OnInitialize(func() error { space.On(app.SpaceInitializing, func(interface{}) error {
disp := dispatcher.FromSpace(space) disp := dispatcher.FromSpace(space)
if disp == nil { if disp == nil {
return newError("dispatcher is not found in the space") return newError("dispatcher is not found in the space")
@ -79,15 +90,33 @@ func (*CacheServer) Start() error {
func (*CacheServer) Close() {} func (*CacheServer) Close() {}
func (s *CacheServer) GetCached(domain string) []net.IP { func (s *CacheServer) GetCached(domain string) []net.IP {
s.RLock() s.Lock()
defer s.RUnlock() defer s.Unlock()
if record, found := s.records[domain]; found && record.A.Expire.After(time.Now()) { if record, found := s.records[domain]; found && !record.Expired() {
return record.A.IPs record.LastAccess = time.Now()
return record.IP
} }
return nil return nil
} }
func (s *CacheServer) tryCleanup() {
s.Lock()
defer s.Unlock()
if len(s.records) > 256 {
domains := make([]string, 0, 256)
for d, r := range s.records {
if r.Expired() {
domains = append(domains, d)
}
}
for _, d := range domains {
delete(s.records, d)
}
}
}
func (s *CacheServer) Get(domain string) []net.IP { func (s *CacheServer) Get(domain string) []net.IP {
if ip, found := s.hosts[domain]; found { if ip, found := s.hosts[domain]; found {
return []net.IP{ip} return []net.IP{ip}
@ -99,6 +128,8 @@ func (s *CacheServer) Get(domain string) []net.IP {
return ips return ips
} }
s.tryCleanup()
for _, server := range s.servers { for _, server := range s.servers {
response := server.QueryA(domain) response := server.QueryA(domain)
select { select {
@ -108,7 +139,9 @@ func (s *CacheServer) Get(domain string) []net.IP {
} }
s.Lock() s.Lock()
s.records[domain] = &DomainRecord{ s.records[domain] = &DomainRecord{
A: a, IP: a.IPs,
Expire: a.Expire,
LastAccess: time.Now(),
} }
s.Unlock() s.Unlock()
log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug()) log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())

View File

@ -0,0 +1,105 @@
package server_test
import (
"context"
"testing"
"v2ray.com/core/app"
"v2ray.com/core/app/dispatcher"
_ "v2ray.com/core/app/dispatcher/impl"
. "v2ray.com/core/app/dns"
_ "v2ray.com/core/app/dns/server"
"v2ray.com/core/app/policy"
_ "v2ray.com/core/app/policy/manager"
"v2ray.com/core/app/proxyman"
_ "v2ray.com/core/app/proxyman/outbound"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy/freedom"
"v2ray.com/core/testing/servers/udp"
. "v2ray.com/ext/assert"
"github.com/miekg/dns"
)
type staticHandler struct {
}
func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ans := new(dns.Msg)
ans.Id = r.Id
for _, q := range r.Question {
if q.Name == "google.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
ans.Answer = append(ans.Answer, rr)
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
ans.Answer = append(ans.Answer, rr)
}
}
w.WriteMsg(ans)
}
func TestUDPServer(t *testing.T) {
assert := With(t)
port := udp.PickPort()
dnsServer := dns.Server{
Addr: "127.0.0.1:" + port.String(),
Net: "udp",
Handler: &staticHandler{},
UDPSize: 1200,
}
go dnsServer.ListenAndServe()
config := &Config{
NameServers: []*net.Endpoint{
{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: uint32(port),
},
},
}
ctx := context.Background()
space := app.NewSpace()
ctx = app.ContextWithSpace(ctx, space)
common.Must(app.AddApplicationToSpace(ctx, config))
common.Must(app.AddApplicationToSpace(ctx, &dispatcher.Config{}))
common.Must(app.AddApplicationToSpace(ctx, &proxyman.OutboundConfig{}))
common.Must(app.AddApplicationToSpace(ctx, &policy.Config{}))
om := proxyman.OutboundHandlerManagerFromSpace(space)
om.AddHandler(ctx, &proxyman.OutboundHandlerConfig{
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
common.Must(space.Initialize())
common.Must(space.Start())
server := FromSpace(space)
assert(server, IsNotNil)
ips := server.Get("google.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
ips = server.Get("facebook.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{9, 9, 9, 9})
dnsServer.Shutdown()
ips = server.Get("google.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
}

View File

@ -4,11 +4,11 @@ import (
"testing" "testing"
. "v2ray.com/core/app/log/internal" . "v2ray.com/core/app/log/internal"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestAccessLog(t *testing.T) { func TestAccessLog(t *testing.T) {
assert := assert.On(t) assert := With(t)
entry := &AccessLog{ entry := &AccessLog{
From: "test_from", From: "test_from",
@ -18,8 +18,8 @@ func TestAccessLog(t *testing.T) {
} }
entryStr := entry.String() entryStr := entry.String()
assert.String(entryStr).Contains("test_from") assert(entryStr, HasSubstring, "test_from")
assert.String(entryStr).Contains("test_to") assert(entryStr, HasSubstring, "test_to")
assert.String(entryStr).Contains("test_reason") assert(entryStr, HasSubstring, "test_reason")
assert.String(entryStr).Contains("Accepted") assert(entryStr, HasSubstring, "Accepted")
} }

26
app/policy/config.go Normal file
View File

@ -0,0 +1,26 @@
package policy
import (
"time"
)
func (s *Second) Duration() time.Duration {
return time.Second * time.Duration(s.Value)
}
func (p *Policy) OverrideWith(another *Policy) {
if another.Timeout != nil {
if another.Timeout.Handshake != nil {
p.Timeout.Handshake = another.Timeout.Handshake
}
if another.Timeout.ConnectionIdle != nil {
p.Timeout.ConnectionIdle = another.Timeout.ConnectionIdle
}
if another.Timeout.UplinkOnly != nil {
p.Timeout.UplinkOnly = another.Timeout.UplinkOnly
}
if another.Timeout.DownlinkOnly != nil {
p.Timeout.DownlinkOnly = another.Timeout.DownlinkOnly
}
}
}

140
app/policy/config.pb.go Normal file
View File

@ -0,0 +1,140 @@
package policy
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Second struct {
Value uint32 `protobuf:"varint,1,opt,name=value" json:"value,omitempty"`
}
func (m *Second) Reset() { *m = Second{} }
func (m *Second) String() string { return proto.CompactTextString(m) }
func (*Second) ProtoMessage() {}
func (*Second) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Second) GetValue() uint32 {
if m != nil {
return m.Value
}
return 0
}
type Policy struct {
Timeout *Policy_Timeout `protobuf:"bytes,1,opt,name=timeout" json:"timeout,omitempty"`
}
func (m *Policy) Reset() { *m = Policy{} }
func (m *Policy) String() string { return proto.CompactTextString(m) }
func (*Policy) ProtoMessage() {}
func (*Policy) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *Policy) GetTimeout() *Policy_Timeout {
if m != nil {
return m.Timeout
}
return nil
}
// Timeout is a message for timeout settings in various stages, in seconds.
type Policy_Timeout struct {
Handshake *Second `protobuf:"bytes,1,opt,name=handshake" json:"handshake,omitempty"`
ConnectionIdle *Second `protobuf:"bytes,2,opt,name=connection_idle,json=connectionIdle" json:"connection_idle,omitempty"`
UplinkOnly *Second `protobuf:"bytes,3,opt,name=uplink_only,json=uplinkOnly" json:"uplink_only,omitempty"`
DownlinkOnly *Second `protobuf:"bytes,4,opt,name=downlink_only,json=downlinkOnly" json:"downlink_only,omitempty"`
}
func (m *Policy_Timeout) Reset() { *m = Policy_Timeout{} }
func (m *Policy_Timeout) String() string { return proto.CompactTextString(m) }
func (*Policy_Timeout) ProtoMessage() {}
func (*Policy_Timeout) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1, 0} }
func (m *Policy_Timeout) GetHandshake() *Second {
if m != nil {
return m.Handshake
}
return nil
}
func (m *Policy_Timeout) GetConnectionIdle() *Second {
if m != nil {
return m.ConnectionIdle
}
return nil
}
func (m *Policy_Timeout) GetUplinkOnly() *Second {
if m != nil {
return m.UplinkOnly
}
return nil
}
func (m *Policy_Timeout) GetDownlinkOnly() *Second {
if m != nil {
return m.DownlinkOnly
}
return nil
}
type Config struct {
Level map[uint32]*Policy `protobuf:"bytes,1,rep,name=level" json:"level,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
}
func (m *Config) Reset() { *m = Config{} }
func (m *Config) String() string { return proto.CompactTextString(m) }
func (*Config) ProtoMessage() {}
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *Config) GetLevel() map[uint32]*Policy {
if m != nil {
return m.Level
}
return nil
}
func init() {
proto.RegisterType((*Second)(nil), "v2ray.core.app.policy.Second")
proto.RegisterType((*Policy)(nil), "v2ray.core.app.policy.Policy")
proto.RegisterType((*Policy_Timeout)(nil), "v2ray.core.app.policy.Policy.Timeout")
proto.RegisterType((*Config)(nil), "v2ray.core.app.policy.Config")
}
func init() { proto.RegisterFile("v2ray.com/core/app/policy/config.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 349 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xc1, 0x4a, 0xeb, 0x40,
0x14, 0x86, 0x49, 0x7a, 0x9b, 0x72, 0x4f, 0x6f, 0xaf, 0x32, 0x58, 0x88, 0x05, 0xa5, 0x14, 0x94,
0xae, 0x26, 0x90, 0x6e, 0x44, 0xb1, 0x62, 0x45, 0x41, 0x10, 0x2c, 0x51, 0x14, 0xdc, 0x94, 0x71,
0x32, 0xda, 0xd0, 0xe9, 0x9c, 0x21, 0xa6, 0x95, 0xbc, 0x86, 0x6f, 0xe0, 0xd6, 0x87, 0xf2, 0x59,
0x24, 0x99, 0x84, 0x6c, 0x5a, 0xe9, 0x6e, 0x72, 0xf8, 0xfe, 0x8f, 0x43, 0xfe, 0x03, 0x87, 0x4b,
0x3f, 0x66, 0x29, 0xe5, 0x38, 0xf7, 0x38, 0xc6, 0xc2, 0x63, 0x5a, 0x7b, 0x1a, 0x65, 0xc4, 0x53,
0x8f, 0xa3, 0x7a, 0x89, 0x5e, 0xa9, 0x8e, 0x31, 0x41, 0xd2, 0x2e, 0xb9, 0x58, 0x50, 0xa6, 0x35,
0x35, 0x4c, 0x6f, 0x1f, 0x9c, 0x3b, 0xc1, 0x51, 0x85, 0x64, 0x07, 0xea, 0x4b, 0x26, 0x17, 0xc2,
0xb5, 0xba, 0x56, 0xbf, 0x15, 0x98, 0x8f, 0xde, 0xb7, 0x0d, 0xce, 0x38, 0x47, 0xc9, 0x19, 0x34,
0x92, 0x68, 0x2e, 0x70, 0x91, 0xe4, 0x48, 0xd3, 0x3f, 0xa0, 0x2b, 0x9d, 0xd4, 0xf0, 0xf4, 0xde,
0xc0, 0x41, 0x99, 0xea, 0x7c, 0xd8, 0xd0, 0x28, 0x86, 0xe4, 0x04, 0xfe, 0x4e, 0x99, 0x0a, 0xdf,
0xa6, 0x6c, 0x26, 0x0a, 0xdd, 0xde, 0x1a, 0x9d, 0xd9, 0x2f, 0xa8, 0x78, 0x72, 0x05, 0x5b, 0x1c,
0x95, 0x12, 0x3c, 0x89, 0x50, 0x4d, 0xa2, 0x50, 0x0a, 0xd7, 0xde, 0x44, 0xf1, 0xbf, 0x4a, 0x5d,
0x87, 0x52, 0x90, 0x21, 0x34, 0x17, 0x5a, 0x46, 0x6a, 0x36, 0x41, 0x25, 0x53, 0xb7, 0xb6, 0x89,
0x03, 0x4c, 0xe2, 0x56, 0xc9, 0x94, 0x8c, 0xa0, 0x15, 0xe2, 0xbb, 0xaa, 0x0c, 0x7f, 0x36, 0x31,
0xfc, 0x2b, 0x33, 0x99, 0xa3, 0xf7, 0x69, 0x81, 0x73, 0x91, 0x17, 0x45, 0x86, 0x50, 0x97, 0x62,
0x29, 0xa4, 0x6b, 0x75, 0x6b, 0xfd, 0xa6, 0xdf, 0x5f, 0xa3, 0x31, 0x34, 0xbd, 0xc9, 0xd0, 0x4b,
0x95, 0xc4, 0x69, 0x60, 0x62, 0x9d, 0x47, 0x80, 0x6a, 0x48, 0xb6, 0xa1, 0x36, 0x13, 0x69, 0xd1,
0x66, 0xf6, 0x24, 0x83, 0xb2, 0xe1, 0xdf, 0x7f, 0x96, 0xa9, 0xaf, 0x38, 0x80, 0x63, 0xfb, 0xc8,
0x1a, 0x9d, 0xc2, 0x2e, 0xc7, 0xf9, 0x6a, 0x7c, 0x6c, 0x3d, 0x39, 0xe6, 0xf5, 0x65, 0xb7, 0x1f,
0xfc, 0x80, 0x65, 0x0b, 0xc6, 0x82, 0x9e, 0x6b, 0x5d, 0x98, 0x9e, 0x9d, 0xfc, 0x02, 0x07, 0x3f,
0x01, 0x00, 0x00, 0xff, 0xff, 0xcf, 0x25, 0x25, 0xc2, 0xab, 0x02, 0x00, 0x00,
}

27
app/policy/config.proto Normal file
View File

@ -0,0 +1,27 @@
syntax = "proto3";
package v2ray.core.app.policy;
option csharp_namespace = "V2Ray.Core.App.Policy";
option go_package = "policy";
option java_package = "com.v2ray.core.app.policy";
option java_multiple_files = true;
message Second {
uint32 value = 1;
}
message Policy {
// Timeout is a message for timeout settings in various stages, in seconds.
message Timeout {
Second handshake = 1;
Second connection_idle = 2;
Second uplink_only = 3;
Second downlink_only = 4;
}
Timeout timeout = 1;
}
message Config {
map<uint32, Policy> level = 1;
}

View File

@ -0,0 +1,66 @@
package manager
import (
"context"
"v2ray.com/core/app/policy"
"v2ray.com/core/common"
)
type Instance struct {
levels map[uint32]*policy.Policy
}
func New(ctx context.Context, config *policy.Config) (*Instance, error) {
levels := config.Level
if levels == nil {
levels = make(map[uint32]*policy.Policy)
}
for _, p := range levels {
g := global()
g.OverrideWith(p)
*p = g
}
return &Instance{
levels: levels,
}, nil
}
func global() policy.Policy {
return policy.Policy{
Timeout: &policy.Policy_Timeout{
Handshake: &policy.Second{Value: 4},
ConnectionIdle: &policy.Second{Value: 300},
UplinkOnly: &policy.Second{Value: 5},
DownlinkOnly: &policy.Second{Value: 30},
},
}
}
// GetPolicy implements policy.Manager.
func (m *Instance) GetPolicy(level uint32) policy.Policy {
if p, ok := m.levels[level]; ok {
return *p
}
return global()
}
// Start implements app.Application.Start().
func (m *Instance) Start() error {
return nil
}
// Close implements app.Application.Close().
func (m *Instance) Close() {
}
// Interface implement app.Application.Interface().
func (m *Instance) Interface() interface{} {
return (*policy.Manager)(nil)
}
func init() {
common.Must(common.RegisterConfig((*policy.Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return New(ctx, config.(*policy.Config))
}))
}

20
app/policy/policy.go Normal file
View File

@ -0,0 +1,20 @@
package policy
import (
"v2ray.com/core/app"
)
// Manager is an utility to manage policy per user level.
type Manager interface {
// GetPolicy returns the Policy for the given user level.
GetPolicy(level uint32) Policy
}
// FromSpace returns the policy.Manager in a space.
func FromSpace(space app.Space) Manager {
app := space.GetApplication((*Manager)(nil))
if app == nil {
return nil
}
return app.(Manager)
}

View File

@ -172,6 +172,11 @@ func (*udpConn) SetWriteDeadline(time.Time) error {
return nil return nil
} }
type connId struct {
src net.Destination
dest net.Destination
}
type udpWorker struct { type udpWorker struct {
sync.RWMutex sync.RWMutex
@ -185,39 +190,43 @@ type udpWorker struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
activeConn map[net.Destination]*udpConn activeConn map[connId]*udpConn
} }
func (w *udpWorker) getConnection(src net.Destination) (*udpConn, bool) { func (w *udpWorker) getConnection(id connId) (*udpConn, bool) {
w.Lock() w.Lock()
defer w.Unlock() defer w.Unlock()
if conn, found := w.activeConn[src]; found { if conn, found := w.activeConn[id]; found {
return conn, true return conn, true
} }
conn := &udpConn{ conn := &udpConn{
input: make(chan *buf.Buffer, 32), input: make(chan *buf.Buffer, 32),
output: func(b []byte) (int, error) { output: func(b []byte) (int, error) {
return w.hub.WriteTo(b, src) return w.hub.WriteTo(b, id.src)
}, },
remote: &net.UDPAddr{ remote: &net.UDPAddr{
IP: src.Address.IP(), IP: id.src.Address.IP(),
Port: int(src.Port), Port: int(id.src.Port),
}, },
local: &net.UDPAddr{ local: &net.UDPAddr{
IP: w.address.IP(), IP: w.address.IP(),
Port: int(w.port), Port: int(w.port),
}, },
} }
w.activeConn[src] = conn w.activeConn[id] = conn
conn.updateActivity() conn.updateActivity()
return conn, false return conn, false
} }
func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) { func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
conn, existing := w.getConnection(source) id := connId{
src: source,
dest: originalDest,
}
conn, existing := w.getConnection(id)
select { select {
case conn.input <- b: case conn.input <- b:
default: default:
@ -240,20 +249,20 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil { if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
log.Trace(newError("connection ends").Base(err)) log.Trace(newError("connection ends").Base(err))
} }
w.removeConn(source) w.removeConn(id)
cancel() cancel()
}() }()
} }
} }
func (w *udpWorker) removeConn(src net.Destination) { func (w *udpWorker) removeConn(id connId) {
w.Lock() w.Lock()
delete(w.activeConn, src) delete(w.activeConn, id)
w.Unlock() w.Unlock()
} }
func (w *udpWorker) Start() error { func (w *udpWorker) Start() error {
w.activeConn = make(map[net.Destination]*udpConn) w.activeConn = make(map[connId]*udpConn, 16)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.ctx = ctx w.ctx = ctx
w.cancel = cancel w.cancel = cancel

View File

@ -1,8 +1,10 @@
package mux package mux
import ( import (
"v2ray.com/core/common/bitmask"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
) )
@ -15,24 +17,10 @@ const (
SessionStatusKeepAlive SessionStatus = 0x04 SessionStatusKeepAlive SessionStatus = 0x04
) )
type Option byte
const ( const (
OptionData Option = 0x01 OptionData bitmask.Byte = 0x01
) )
func (o Option) Has(x Option) bool {
return (o & x) == x
}
func (o *Option) Add(x Option) {
*o = (*o | x)
}
func (o *Option) Clear(x Option) {
*o = (*o & (^x))
}
type TargetNetwork byte type TargetNetwork byte
const ( const (
@ -40,14 +28,6 @@ const (
TargetNetworkUDP TargetNetwork = 0x02 TargetNetworkUDP TargetNetwork = 0x02
) )
type AddressType byte
const (
AddressTypeIPv4 AddressType = 0x01
AddressTypeDomain AddressType = 0x02
AddressTypeIPv6 AddressType = 0x03
)
/* /*
Frame format Frame format
2 bytes - length 2 bytes - length
@ -62,10 +42,10 @@ n bytes - address
*/ */
type FrameMetadata struct { type FrameMetadata struct {
SessionID uint16
SessionStatus SessionStatus
Target net.Destination Target net.Destination
Option Option SessionID uint16
Option bitmask.Byte
SessionStatus SessionStatus
} }
func (f FrameMetadata) AsSupplier() buf.Supplier { func (f FrameMetadata) AsSupplier() buf.Supplier {
@ -92,17 +72,21 @@ func (f FrameMetadata) AsSupplier() buf.Supplier {
addr := f.Target.Address addr := f.Target.Address
switch addr.Family() { switch addr.Family() {
case net.AddressFamilyIPv4: case net.AddressFamilyIPv4:
b = append(b, byte(AddressTypeIPv4)) b = append(b, byte(protocol.AddressTypeIPv4))
b = append(b, addr.IP()...) b = append(b, addr.IP()...)
length += 5 length += 5
case net.AddressFamilyIPv6: case net.AddressFamilyIPv6:
b = append(b, byte(AddressTypeIPv6)) b = append(b, byte(protocol.AddressTypeIPv6))
b = append(b, addr.IP()...) b = append(b, addr.IP()...)
length += 17 length += 17
case net.AddressFamilyDomain: case net.AddressFamilyDomain:
nDomain := len(addr.Domain()) domain := addr.Domain()
b = append(b, byte(AddressTypeDomain), byte(nDomain)) if protocol.IsDomainTooLong(domain) {
b = append(b, addr.Domain()...) return 0, newError("domain name too long: ", domain)
}
nDomain := len(domain)
b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain))
b = append(b, domain...)
length += nDomain + 2 length += nDomain + 2
} }
} }
@ -120,7 +104,7 @@ func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
f := &FrameMetadata{ f := &FrameMetadata{
SessionID: serial.BytesToUint16(b[:2]), SessionID: serial.BytesToUint16(b[:2]),
SessionStatus: SessionStatus(b[2]), SessionStatus: SessionStatus(b[2]),
Option: Option(b[3]), Option: bitmask.Byte(b[3]),
} }
b = b[4:] b = b[4:]
@ -128,18 +112,18 @@ func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
if f.SessionStatus == SessionStatusNew { if f.SessionStatus == SessionStatusNew {
network := TargetNetwork(b[0]) network := TargetNetwork(b[0])
port := net.PortFromBytes(b[1:3]) port := net.PortFromBytes(b[1:3])
addrType := AddressType(b[3]) addrType := protocol.AddressType(b[3])
b = b[4:] b = b[4:]
var addr net.Address var addr net.Address
switch addrType { switch addrType {
case AddressTypeIPv4: case protocol.AddressTypeIPv4:
addr = net.IPAddress(b[0:4]) addr = net.IPAddress(b[0:4])
b = b[4:] b = b[4:]
case AddressTypeIPv6: case protocol.AddressTypeIPv6:
addr = net.IPAddress(b[0:16]) addr = net.IPAddress(b[0:16])
b = b[16:] b = b[16:]
case AddressTypeDomain: case protocol.AddressTypeDomain:
nDomain := int(b[0]) nDomain := int(b[0])
addr = net.DomainAddress(string(b[1 : 1+nDomain])) addr = net.DomainAddress(string(b[1 : 1+nDomain]))
b = b[nDomain+1:] b = b[nDomain+1:]

View File

@ -90,7 +90,19 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(muxCoolAddress, muxCoolPort)) ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(muxCoolAddress, muxCoolPort))
pipe := ray.NewRay(ctx) pipe := ray.NewRay(ctx)
go p.Process(ctx, pipe, dialer)
go func() {
if err := p.Process(ctx, pipe, dialer); err != nil {
cancel()
traceErr := errors.New("failed to handler mux client connection").Base(err)
if err != io.EOF && err != context.Canceled {
traceErr = traceErr.AtWarning()
}
log.Trace(traceErr)
}
}()
c := &Client{ c := &Client{
sessionManager: NewSessionManager(), sessionManager: NewSessionManager(),
inboundRay: pipe, inboundRay: pipe,
@ -104,6 +116,7 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
return c, nil return c, nil
} }
// Closed returns true if this Client is closed.
func (m *Client) Closed() bool { func (m *Client) Closed() bool {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@ -148,7 +161,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
log.Trace(newError("dispatching request to ", dest)) log.Trace(newError("dispatching request to ", dest))
data, _ := s.input.ReadTimeout(time.Millisecond * 500) data, _ := s.input.ReadTimeout(time.Millisecond * 500)
if err := writer.Write(data); err != nil { if err := writer.WriteMultiBuffer(data); err != nil {
log.Trace(newError("failed to write first payload").Base(err)) log.Trace(newError("failed to write first payload").Base(err))
return return
} }
@ -179,26 +192,25 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
return true return true
} }
func drain(reader io.Reader) error { func drain(reader *buf.BufferedReader) error {
buf.Copy(NewStreamReader(reader), buf.Discard) return buf.Copy(NewStreamReader(reader), buf.Discard)
return nil
} }
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader io.Reader) error { func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) { if meta.Option.Has(OptionData) {
return drain(reader) return drain(reader)
} }
return nil return nil
} }
func (m *Client) handleStatusNew(meta *FrameMetadata, reader io.Reader) error { func (m *Client) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) { if meta.Option.Has(OptionData) {
return drain(reader) return drain(reader)
} }
return nil return nil
} }
func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error { func (m *Client) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) { if !meta.Option.Has(OptionData) {
return nil return nil
} }
@ -209,7 +221,7 @@ func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
return drain(reader) return drain(reader)
} }
func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error { func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := m.sessionManager.Get(meta.SessionID); found { if s, found := m.sessionManager.Get(meta.SessionID); found {
s.Close() s.Close()
} }
@ -222,11 +234,10 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) fetchOutput() { func (m *Client) fetchOutput() {
defer m.cancel() defer m.cancel()
reader := buf.ToBytesReader(m.inboundRay.InboundOutput()) reader := buf.NewBufferedReader(m.inboundRay.InboundOutput())
metaReader := NewMetadataReader(reader)
for { for {
meta, err := metaReader.Read() meta, err := ReadMetadata(reader)
if err != nil { if err != nil {
if errors.Cause(err) != io.EOF { if errors.Cause(err) != io.EOF {
log.Trace(newError("failed to read metadata").Base(err)) log.Trace(newError("failed to read metadata").Base(err))
@ -263,7 +274,7 @@ type Server struct {
func NewServer(ctx context.Context) *Server { func NewServer(ctx context.Context) *Server {
s := &Server{} s := &Server{}
space := app.SpaceFromContext(ctx) space := app.SpaceFromContext(ctx)
space.OnInitialize(func() error { space.On(app.SpaceInitializing, func(interface{}) error {
d := dispatcher.FromSpace(space) d := dispatcher.FromSpace(space)
if d == nil { if d == nil {
return newError("no dispatcher in space") return newError("no dispatcher in space")
@ -304,14 +315,14 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
s.Close() s.Close()
} }
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader io.Reader) error { func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) { if meta.Option.Has(OptionData) {
return drain(reader) return drain(reader)
} }
return nil return nil
} }
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader io.Reader) error { func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
log.Trace(newError("received request for ", meta.Target)) log.Trace(newError("received request for ", meta.Target))
inboundRay, err := w.dispatcher.Dispatch(ctx, meta.Target) inboundRay, err := w.dispatcher.Dispatch(ctx, meta.Target)
if err != nil { if err != nil {
@ -338,7 +349,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
return nil return nil
} }
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error { func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) { if !meta.Option.Has(OptionData) {
return nil return nil
} }
@ -348,7 +359,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) e
return drain(reader) return drain(reader)
} }
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error { func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := w.sessionManager.Get(meta.SessionID); found { if s, found := w.sessionManager.Get(meta.SessionID); found {
s.Close() s.Close()
} }
@ -358,9 +369,8 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) er
return nil return nil
} }
func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error { func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error {
metaReader := NewMetadataReader(reader) meta, err := ReadMetadata(reader)
meta, err := metaReader.Read()
if err != nil { if err != nil {
return newError("failed to read metadata").Base(err) return newError("failed to read metadata").Base(err)
} }
@ -386,7 +396,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error
func (w *ServerWorker) run(ctx context.Context) { func (w *ServerWorker) run(ctx context.Context) {
input := w.outboundRay.OutboundInput() input := w.outboundRay.OutboundInput()
reader := buf.ToBytesReader(input) reader := buf.NewBufferedReader(input)
defer w.sessionManager.Close() defer w.sessionManager.Close()

View File

@ -9,14 +9,14 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray" "v2ray.com/core/transport/ray"
. "v2ray.com/ext/assert"
) )
func readAll(reader buf.Reader) (buf.MultiBuffer, error) { func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
mb := buf.NewMultiBuffer() var mb buf.MultiBuffer
for { for {
b, err := reader.Read() b, err := reader.ReadMultiBuffer()
if err == io.EOF { if err == io.EOF {
break break
} }
@ -29,7 +29,7 @@ func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
} }
func TestReaderWriter(t *testing.T) { func TestReaderWriter(t *testing.T) {
assert := assert.On(t) assert := With(t)
stream := ray.NewStream(context.Background()) stream := ray.NewStream(context.Background())
@ -45,98 +45,98 @@ func TestReaderWriter(t *testing.T) {
writePayload := func(writer *Writer, payload ...byte) error { writePayload := func(writer *Writer, payload ...byte) error {
b := buf.New() b := buf.New()
b.Append(payload) b.Append(payload)
return writer.Write(buf.NewMultiBufferValue(b)) return writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
} }
assert.Error(writePayload(writer, 'a', 'b', 'c', 'd')).IsNil() assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil)
assert.Error(writePayload(writer2)).IsNil() assert(writePayload(writer2), IsNil)
assert.Error(writePayload(writer, 'e', 'f', 'g', 'h')).IsNil() assert(writePayload(writer, 'e', 'f', 'g', 'h'), IsNil)
assert.Error(writePayload(writer3, 'x')).IsNil() assert(writePayload(writer3, 'x'), IsNil)
writer.Close() writer.Close()
writer3.Close() writer3.Close()
assert.Error(writePayload(writer2, 'y')).IsNil() assert(writePayload(writer2, 'y'), IsNil)
writer2.Close() writer2.Close()
bytesReader := buf.ToBytesReader(stream) bytesReader := buf.NewBufferedReader(stream)
metaReader := NewMetadataReader(bytesReader)
streamReader := NewStreamReader(bytesReader) streamReader := NewStreamReader(bytesReader)
meta, err := metaReader.Read() meta, err := ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Uint16(meta.SessionID).Equals(1) assert(meta.SessionID, Equals, uint16(1))
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
assert.Destination(meta.Target).Equals(dest) assert(meta.Target, Equals, dest)
assert.Byte(byte(meta.Option)).Equals(byte(OptionData)) assert(byte(meta.Option), Equals, byte(OptionData))
data, err := readAll(streamReader) data, err := readAll(streamReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(len(data)).Equals(1) assert(len(data), Equals, 1)
assert.String(data[0].String()).Equals("abcd") assert(data[0].String(), Equals, "abcd")
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
assert.Uint16(meta.SessionID).Equals(2) assert(meta.SessionID, Equals, uint16(2))
assert.Byte(byte(meta.Option)).Equals(0) assert(byte(meta.Option), Equals, byte(0))
assert.Destination(meta.Target).Equals(dest2) assert(meta.Target, Equals, dest2)
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
assert.Uint16(meta.SessionID).Equals(1) assert(meta.SessionID, Equals, uint16(1))
assert.Byte(byte(meta.Option)).Equals(1) assert(byte(meta.Option), Equals, byte(1))
data, err = readAll(streamReader) data, err = readAll(streamReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(len(data)).Equals(1) assert(len(data), Equals, 1)
assert.String(data[0].String()).Equals("efgh") assert(data[0].String(), Equals, "efgh")
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
assert.Uint16(meta.SessionID).Equals(3) assert(meta.SessionID, Equals, uint16(3))
assert.Byte(byte(meta.Option)).Equals(1) assert(byte(meta.Option), Equals, byte(1))
assert.Destination(meta.Target).Equals(dest3) assert(meta.Target, Equals, dest3)
data, err = readAll(streamReader) data, err = readAll(streamReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(len(data)).Equals(1) assert(len(data), Equals, 1)
assert.String(data[0].String()).Equals("x") assert(data[0].String(), Equals, "x")
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
assert.Uint16(meta.SessionID).Equals(1) assert(meta.SessionID, Equals, uint16(1))
assert.Byte(byte(meta.Option)).Equals(0) assert(byte(meta.Option), Equals, byte(0))
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
assert.Uint16(meta.SessionID).Equals(3) assert(meta.SessionID, Equals, uint16(3))
assert.Byte(byte(meta.Option)).Equals(0) assert(byte(meta.Option), Equals, byte(0))
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
assert.Uint16(meta.SessionID).Equals(2) assert(meta.SessionID, Equals, uint16(2))
assert.Byte(byte(meta.Option)).Equals(1) assert(byte(meta.Option), Equals, byte(1))
data, err = readAll(streamReader) data, err = readAll(streamReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(len(data)).Equals(1) assert(len(data), Equals, 1)
assert.String(data[0].String()).Equals("y") assert(data[0].String(), Equals, "y")
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd)) assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
assert.Uint16(meta.SessionID).Equals(2) assert(meta.SessionID, Equals, uint16(2))
assert.Byte(byte(meta.Option)).Equals(0) assert(byte(meta.Option), Equals, byte(0))
stream.Close() stream.Close()
meta, err = metaReader.Read() meta, err = ReadMetadata(bytesReader)
assert.Error(err).IsNotNil() assert(err, IsNotNil)
assert(meta, IsNil)
} }

View File

@ -7,20 +7,9 @@ import (
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
) )
type MetadataReader struct { // ReadMetadata reads FrameMetadata from the given reader.
reader io.Reader func ReadMetadata(reader io.Reader) (*FrameMetadata, error) {
buffer []byte metaLen, err := serial.ReadUint16(reader)
}
func NewMetadataReader(reader io.Reader) *MetadataReader {
return &MetadataReader{
reader: reader,
buffer: make([]byte, 1024),
}
}
func (r *MetadataReader) Read() (*FrameMetadata, error) {
metaLen, err := serial.ReadUint16(r.reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -28,17 +17,22 @@ func (r *MetadataReader) Read() (*FrameMetadata, error) {
return nil, newError("invalid metalen ", metaLen).AtWarning() return nil, newError("invalid metalen ", metaLen).AtWarning()
} }
if _, err := io.ReadFull(r.reader, r.buffer[:metaLen]); err != nil { b := buf.New()
defer b.Release()
if err := b.Reset(buf.ReadFullFrom(reader, int(metaLen))); err != nil {
return nil, err return nil, err
} }
return ReadFrameFrom(r.buffer) return ReadFrameFrom(b.Bytes())
} }
// PacketReader is an io.Reader that reads whole chunk of Mux frames every time.
type PacketReader struct { type PacketReader struct {
reader io.Reader reader io.Reader
eof bool eof bool
} }
// NewPacketReader creates a new PacketReader.
func NewPacketReader(reader io.Reader) *PacketReader { func NewPacketReader(reader io.Reader) *PacketReader {
return &PacketReader{ return &PacketReader{
reader: reader, reader: reader,
@ -46,7 +40,8 @@ func NewPacketReader(reader io.Reader) *PacketReader {
} }
} }
func (r *PacketReader) Read() (buf.MultiBuffer, error) { // ReadMultiBuffer implements buf.Reader.
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.eof { if r.eof {
return nil, io.EOF return nil, io.EOF
} }
@ -70,19 +65,22 @@ func (r *PacketReader) Read() (buf.MultiBuffer, error) {
return buf.NewMultiBufferValue(b), nil return buf.NewMultiBufferValue(b), nil
} }
// StreamReader reads Mux frame as a stream.
type StreamReader struct { type StreamReader struct {
reader io.Reader reader *buf.BufferedReader
leftOver int leftOver int
} }
func NewStreamReader(reader io.Reader) *StreamReader { // NewStreamReader creates a new StreamReader.
func NewStreamReader(reader *buf.BufferedReader) *StreamReader {
return &StreamReader{ return &StreamReader{
reader: reader, reader: reader,
leftOver: -1, leftOver: -1,
} }
} }
func (r *StreamReader) Read() (buf.MultiBuffer, error) { // ReadMultiBuffer implmenets buf.Reader.
func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.leftOver == 0 { if r.leftOver == 0 {
r.leftOver = -1 r.leftOver = -1
return nil, io.EOF return nil, io.EOF
@ -96,24 +94,7 @@ func (r *StreamReader) Read() (buf.MultiBuffer, error) {
r.leftOver = int(size) r.leftOver = int(size)
} }
mb := buf.NewMultiBuffer() mb, err := r.reader.ReadAtMost(r.leftOver)
for r.leftOver > 0 { r.leftOver -= mb.Len()
readLen := buf.Size return mb, err
if r.leftOver < readLen {
readLen = r.leftOver
}
b := buf.New()
if err := b.AppendSupplier(func(bb []byte) (int, error) {
return r.reader.Read(bb[:readLen])
}); err != nil {
mb.Release()
return nil, err
}
r.leftOver -= b.Len()
mb.Append(b)
if b.Len() < readLen {
break
}
}
return mb, nil
} }

View File

@ -1,7 +1,6 @@
package mux package mux
import ( import (
"io"
"sync" "sync"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
@ -19,7 +18,7 @@ type SessionManager struct {
func NewSessionManager() *SessionManager { func NewSessionManager() *SessionManager {
return &SessionManager{ return &SessionManager{
count: 0, count: 0,
sessions: make(map[uint16]*Session, 32), sessions: make(map[uint16]*Session, 16),
} }
} }
@ -58,6 +57,10 @@ func (m *SessionManager) Add(s *Session) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
if m.closed {
return
}
m.sessions[s.ID] = s m.sessions[s.ID] = s
} }
@ -65,6 +68,10 @@ func (m *SessionManager) Remove(id uint16) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
if m.closed {
return
}
delete(m.sessions, id) delete(m.sessions, id)
} }
@ -111,9 +118,10 @@ func (m *SessionManager) Close() {
s.output.Close() s.output.Close()
} }
m.sessions = make(map[uint16]*Session) m.sessions = nil
} }
// Session represents a client connection in a Mux connection.
type Session struct { type Session struct {
input ray.InputStream input ray.InputStream
output ray.OutputStream output ray.OutputStream
@ -122,13 +130,15 @@ type Session struct {
transferType protocol.TransferType transferType protocol.TransferType
} }
// Close closes all resources associated with this session.
func (s *Session) Close() { func (s *Session) Close() {
s.output.Close() s.output.Close()
s.input.Close() s.input.Close()
s.parent.Remove(s.ID) s.parent.Remove(s.ID)
} }
func (s *Session) NewReader(reader io.Reader) buf.Reader { // NewReader creates a buf.Reader based on the transfer type of this Session.
func (s *Session) NewReader(reader *buf.BufferedReader) buf.Reader {
if s.transferType == protocol.TransferTypeStream { if s.transferType == protocol.TransferTypeStream {
return NewStreamReader(reader) return NewStreamReader(reader)
} }

View File

@ -4,34 +4,36 @@ import (
"testing" "testing"
. "v2ray.com/core/app/proxyman/mux" . "v2ray.com/core/app/proxyman/mux"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestSessionManagerAdd(t *testing.T) { func TestSessionManagerAdd(t *testing.T) {
assert := assert.On(t) assert := With(t)
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate() s := m.Allocate()
assert.Uint16(s.ID).Equals(1) assert(s.ID, Equals, uint16(1))
assert(m.Size(), Equals, 1)
s = m.Allocate() s = m.Allocate()
assert.Uint16(s.ID).Equals(2) assert(s.ID, Equals, uint16(2))
assert(m.Size(), Equals, 2)
s = &Session{ s = &Session{
ID: 4, ID: 4,
} }
m.Add(s) m.Add(s)
assert.Uint16(s.ID).Equals(4) assert(s.ID, Equals, uint16(4))
} }
func TestSessionManagerClose(t *testing.T) { func TestSessionManagerClose(t *testing.T) {
assert := assert.On(t) assert := With(t)
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate() s := m.Allocate()
assert.Bool(m.CloseIfNoSession()).IsFalse() assert(m.CloseIfNoSession(), IsFalse)
m.Remove(s.ID) m.Remove(s.ID)
assert.Bool(m.CloseIfNoSession()).IsTrue() assert(m.CloseIfNoSession(), IsTrue)
} }

View File

@ -1,8 +1,7 @@
package mux package mux
import ( import (
"runtime" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
@ -10,9 +9,9 @@ import (
) )
type Writer struct { type Writer struct {
id uint16
dest net.Destination dest net.Destination
writer buf.Writer writer buf.Writer
id uint16
followup bool followup bool
transferType protocol.TransferType transferType protocol.TransferType
} }
@ -54,51 +53,47 @@ func (w *Writer) getNextFrameMeta() FrameMetadata {
func (w *Writer) writeMetaOnly() error { func (w *Writer) writeMetaOnly() error {
meta := w.getNextFrameMeta() meta := w.getNextFrameMeta()
b := buf.New() b := buf.New()
if err := b.AppendSupplier(meta.AsSupplier()); err != nil { if err := b.Reset(meta.AsSupplier()); err != nil {
return err return err
} }
runtime.KeepAlive(meta) return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
return w.writer.Write(buf.NewMultiBufferValue(b))
} }
func (w *Writer) writeData(mb buf.MultiBuffer) error { func (w *Writer) writeData(mb buf.MultiBuffer) error {
meta := w.getNextFrameMeta() meta := w.getNextFrameMeta()
meta.Option.Add(OptionData) meta.Option.Set(OptionData)
frame := buf.New() frame := buf.New()
if err := frame.AppendSupplier(meta.AsSupplier()); err != nil { if err := frame.Reset(meta.AsSupplier()); err != nil {
return err return err
} }
runtime.KeepAlive(meta)
if err := frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))); err != nil { if err := frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))); err != nil {
return err return err
} }
mb2 := buf.NewMultiBuffer() mb2 := buf.NewMultiBufferCap(len(mb) + 1)
mb2.Append(frame) mb2.Append(frame)
mb2.AppendMulti(mb) mb2.AppendMulti(mb)
return w.writer.Write(mb2) return w.writer.WriteMultiBuffer(mb2)
} }
// Write implements buf.MultiBufferWriter. // WriteMultiBuffer implements buf.Writer.
func (w *Writer) Write(mb buf.MultiBuffer) error { func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
if mb.IsEmpty() { if mb.IsEmpty() {
return w.writeMetaOnly() return w.writeMetaOnly()
} }
if w.transferType == protocol.TransferTypeStream {
const chunkSize = 8 * 1024
for !mb.IsEmpty() { for !mb.IsEmpty() {
slice := mb.SliceBySize(chunkSize) var chunk buf.MultiBuffer
if err := w.writeData(slice); err != nil { if w.transferType == protocol.TransferTypeStream {
return err chunk = mb.SliceBySize(8 * 1024)
}
}
} else { } else {
for _, b := range mb { chunk = buf.NewMultiBufferValue(mb.SplitFirst())
if err := w.writeData(buf.NewMultiBufferValue(b)); err != nil {
return err
} }
if err := w.writeData(chunk); err != nil {
return err
} }
} }
@ -112,8 +107,7 @@ func (w *Writer) Close() {
} }
frame := buf.New() frame := buf.New()
frame.AppendSupplier(meta.AsSupplier()) common.Must(frame.Reset(meta.AsSupplier()))
runtime.KeepAlive(meta)
w.writer.Write(buf.NewMultiBufferValue(frame)) w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame))
} }

View File

@ -33,7 +33,7 @@ func NewHandler(ctx context.Context, config *proxyman.OutboundHandlerConfig) (*H
if space == nil { if space == nil {
return nil, newError("no space in context") return nil, newError("no space in context")
} }
space.OnInitialize(func() error { space.On(app.SpaceInitializing, func(interface{}) error {
ohm := proxyman.OutboundHandlerManagerFromSpace(space) ohm := proxyman.OutboundHandlerManagerFromSpace(space)
if ohm == nil { if ohm == nil {
return newError("no OutboundManager in space") return newError("no OutboundManager in space")
@ -78,6 +78,7 @@ func (h *Handler) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) {
err := h.mux.Dispatch(ctx, outboundRay) err := h.mux.Dispatch(ctx, outboundRay)
if err != nil { if err != nil {
log.Trace(newError("failed to process outbound traffic").Base(err)) log.Trace(newError("failed to process outbound traffic").Base(err))
outboundRay.OutboundOutput().CloseError()
} }
} else { } else {
err := h.proxy.Process(ctx, outboundRay, h) err := h.proxy.Process(ctx, outboundRay, h)
@ -122,8 +123,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
} }
var ( var (
_ buf.MultiBufferReader = (*Connection)(nil) _ buf.Reader = (*Connection)(nil)
_ buf.MultiBufferWriter = (*Connection)(nil) _ buf.Writer = (*Connection)(nil)
) )
type Connection struct { type Connection struct {
@ -132,8 +133,7 @@ type Connection struct {
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
bytesReader io.Reader reader *buf.BufferedReader
reader buf.Reader
writer buf.Writer writer buf.Writer
} }
@ -148,8 +148,7 @@ func NewConnection(stream ray.Ray) *Connection {
IP: []byte{0, 0, 0, 0}, IP: []byte{0, 0, 0, 0},
Port: 0, Port: 0,
}, },
bytesReader: buf.ToBytesReader(stream.InboundOutput()), reader: buf.NewBufferedReader(stream.InboundOutput()),
reader: stream.InboundOutput(),
writer: stream.InboundInput(), writer: stream.InboundInput(),
} }
} }
@ -159,11 +158,11 @@ func (v *Connection) Read(b []byte) (int, error) {
if v.closed { if v.closed {
return 0, io.EOF return 0, io.EOF
} }
return v.bytesReader.Read(b) return v.reader.Read(b)
} }
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
return v.reader.Read() return v.reader.ReadMultiBuffer()
} }
// Write implements net.Conn.Write(). // Write implements net.Conn.Write().
@ -171,14 +170,19 @@ func (v *Connection) Write(b []byte) (int, error) {
if v.closed { if v.closed {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
return buf.ToBytesWriter(v.writer).Write(b)
l := len(b)
mb := buf.NewMultiBufferCap(l/buf.Size + 1)
mb.Write(b)
return l, v.writer.WriteMultiBuffer(mb)
} }
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if v.closed { if v.closed {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
return v.writer.Write(mb)
return v.writer.WriteMultiBuffer(mb)
} }
// Close implements net.Conn.Close(). // Close implements net.Conn.Close().

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"regexp" "regexp"
"strings" "strings"
"sync"
"time"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
@ -64,13 +66,123 @@ func (v *AnyCondition) Len() int {
return len(*v) return len(*v)
} }
type PlainDomainMatcher string type timedResult struct {
timestamp time.Time
func NewPlainDomainMatcher(pattern string) Condition { result bool
return PlainDomainMatcher(pattern)
} }
func (v PlainDomainMatcher) Apply(ctx context.Context) bool { type CachableDomainMatcher struct {
sync.Mutex
matchers []domainMatcher
cache map[string]timedResult
lastScan time.Time
}
func NewCachableDomainMatcher() *CachableDomainMatcher {
return &CachableDomainMatcher{
matchers: make([]domainMatcher, 0, 64),
cache: make(map[string]timedResult, 512),
}
}
func (m *CachableDomainMatcher) Add(domain *Domain) error {
switch domain.Type {
case Domain_Plain:
m.matchers = append(m.matchers, NewPlainDomainMatcher(domain.Value))
case Domain_Regex:
rm, err := NewRegexpDomainMatcher(domain.Value)
if err != nil {
return err
}
m.matchers = append(m.matchers, rm)
case Domain_Domain:
m.matchers = append(m.matchers, NewSubDomainMatcher(domain.Value))
default:
return newError("unknown domain type: ", domain.Type).AtError()
}
return nil
}
func (m *CachableDomainMatcher) applyInternal(domain string) bool {
for _, matcher := range m.matchers {
if matcher.Apply(domain) {
return true
}
}
return false
}
type cacheResult int
const (
cacheMiss cacheResult = iota
cacheHitTrue
cacheHitFalse
)
func (m *CachableDomainMatcher) findInCache(domain string) cacheResult {
m.Lock()
defer m.Unlock()
r, f := m.cache[domain]
if !f {
return cacheMiss
}
r.timestamp = time.Now()
m.cache[domain] = r
if r.result {
return cacheHitTrue
}
return cacheHitFalse
}
func (m *CachableDomainMatcher) ApplyDomain(domain string) bool {
if len(m.matchers) < 64 {
return m.applyInternal(domain)
}
cr := m.findInCache(domain)
if cr == cacheHitTrue {
return true
}
if cr == cacheHitFalse {
return false
}
r := m.applyInternal(domain)
m.Lock()
defer m.Unlock()
m.cache[domain] = timedResult{
result: r,
timestamp: time.Now(),
}
now := time.Now()
if len(m.cache) > 256 && now.Sub(m.lastScan)/time.Second > 5 {
remove := make([]string, 0, 128)
now := time.Now()
for k, v := range m.cache {
if now.Sub(v.timestamp)/time.Second > 60 {
remove = append(remove, k)
}
}
for _, v := range remove {
delete(m.cache, v)
}
m.lastScan = now
}
return r
}
func (m *CachableDomainMatcher) Apply(ctx context.Context) bool {
dest, ok := proxy.TargetFromContext(ctx) dest, ok := proxy.TargetFromContext(ctx)
if !ok { if !ok {
return false return false
@ -79,7 +191,20 @@ func (v PlainDomainMatcher) Apply(ctx context.Context) bool {
if !dest.Address.Family().IsDomain() { if !dest.Address.Family().IsDomain() {
return false return false
} }
domain := dest.Address.Domain() return m.ApplyDomain(dest.Address.Domain())
}
type domainMatcher interface {
Apply(domain string) bool
}
type PlainDomainMatcher string
func NewPlainDomainMatcher(pattern string) PlainDomainMatcher {
return PlainDomainMatcher(pattern)
}
func (v PlainDomainMatcher) Apply(domain string) bool {
return strings.Contains(domain, string(v)) return strings.Contains(domain, string(v))
} }
@ -97,33 +222,17 @@ func NewRegexpDomainMatcher(pattern string) (*RegexpDomainMatcher, error) {
}, nil }, nil
} }
func (v *RegexpDomainMatcher) Apply(ctx context.Context) bool { func (v *RegexpDomainMatcher) Apply(domain string) bool {
dest, ok := proxy.TargetFromContext(ctx)
if !ok {
return false
}
if !dest.Address.Family().IsDomain() {
return false
}
domain := dest.Address.Domain()
return v.pattern.MatchString(strings.ToLower(domain)) return v.pattern.MatchString(strings.ToLower(domain))
} }
type SubDomainMatcher string type SubDomainMatcher string
func NewSubDomainMatcher(p string) Condition { func NewSubDomainMatcher(p string) SubDomainMatcher {
return SubDomainMatcher(p) return SubDomainMatcher(p)
} }
func (m SubDomainMatcher) Apply(ctx context.Context) bool { func (m SubDomainMatcher) Apply(domain string) bool {
dest, ok := proxy.TargetFromContext(ctx)
if !ok {
return false
}
if !dest.Address.Family().IsDomain() {
return false
}
domain := dest.Address.Domain()
pattern := string(m) pattern := string(m)
if !strings.HasSuffix(domain, pattern) { if !strings.HasSuffix(domain, pattern) {
return false return false
@ -149,8 +258,9 @@ func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error)
func (v *CIDRMatcher) Apply(ctx context.Context) bool { func (v *CIDRMatcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4) ips := make([]net.IP, 0, 4)
if resolveIPs, ok := proxy.ResolvedIPsFromContext(ctx); ok { if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok {
for _, rip := range resolveIPs { resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
if !rip.Family().IsIPv6() { if !rip.Family().IsIPv6() {
continue continue
} }
@ -192,8 +302,9 @@ func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
func (v *IPv4Matcher) Apply(ctx context.Context) bool { func (v *IPv4Matcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4) ips := make([]net.IP, 0, 4)
if resolveIPs, ok := proxy.ResolvedIPsFromContext(ctx); ok { if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok {
for _, rip := range resolveIPs { resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
if !rip.Family().IsIPv4() { if !rip.Family().IsIPv4() {
continue continue
} }

View File

@ -2,57 +2,66 @@ package router_test
import ( import (
"context" "context"
"os"
"path/filepath"
"strconv"
"testing" "testing"
"time"
proto "github.com/golang/protobuf/proto"
. "v2ray.com/core/app/router" . "v2ray.com/core/app/router"
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/platform"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/proxy" "v2ray.com/core/proxy"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
"v2ray.com/ext/sysio"
) )
func TestSubDomainMatcher(t *testing.T) { func TestSubDomainMatcher(t *testing.T) {
assert := assert.On(t) assert := With(t)
cases := []struct { cases := []struct {
pattern string pattern string
input context.Context input string
output bool output bool
}{ }{
{ {
pattern: "v2ray.com", pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v2ray.com"), 80)), input: "www.v2ray.com",
output: true, output: true,
}, },
{ {
pattern: "v2ray.com", pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.com"), 80)), input: "v2ray.com",
output: true, output: true,
}, },
{ {
pattern: "v2ray.com", pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v3ray.com"), 80)), input: "www.v3ray.com",
output: false, output: false,
}, },
{ {
pattern: "v2ray.com", pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("2ray.com"), 80)), input: "2ray.com",
output: false, output: false,
}, },
{ {
pattern: "v2ray.com", pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("xv2ray.com"), 80)), input: "xv2ray.com",
output: false, output: false,
}, },
} }
for _, test := range cases { for _, test := range cases {
matcher := NewSubDomainMatcher(test.pattern) matcher := NewSubDomainMatcher(test.pattern)
assert.Bool(matcher.Apply(test.input) == test.output).IsTrue() assert(matcher.Apply(test.input) == test.output, IsTrue)
} }
} }
func TestRoutingRule(t *testing.T) { func TestRoutingRule(t *testing.T) {
assert := assert.On(t) assert := With(t)
type ruleTest struct { type ruleTest struct {
input context.Context input context.Context
@ -172,10 +181,56 @@ func TestRoutingRule(t *testing.T) {
for _, test := range cases { for _, test := range cases {
cond, err := test.rule.BuildCondition() cond, err := test.rule.BuildCondition()
assert.Error(err).IsNil() assert(err, IsNil)
for _, t := range test.test { for _, t := range test.test {
assert.Bool(cond.Apply(t.input)).Equals(t.output) assert(cond.Apply(t.input), Equals, t.output)
} }
} }
} }
func loadGeoSite(country string) ([]*Domain, error) {
geositeBytes, err := sysio.ReadAsset("geosite.dat")
if err != nil {
return nil, err
}
var geositeList GeoSiteList
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
return nil, err
}
for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
}
}
return nil, errors.New("country not found: " + country)
}
func TestChinaSites(t *testing.T) {
assert := With(t)
common.Must(sysio.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geosite.dat")))
domains, err := loadGeoSite("CN")
assert(err, IsNil)
matcher := NewCachableDomainMatcher()
for _, d := range domains {
assert(matcher.Add(d), IsNil)
}
assert(matcher.ApplyDomain("163.com"), IsTrue)
assert(matcher.ApplyDomain("163.com"), IsTrue)
assert(matcher.ApplyDomain("164.com"), IsFalse)
assert(matcher.ApplyDomain("164.com"), IsFalse)
for i := 0; i < 1024; i++ {
assert(matcher.ApplyDomain(strconv.Itoa(i)+".not-exists.com"), IsFalse)
}
time.Sleep(time.Second * 10)
for i := 0; i < 1024; i++ {
assert(matcher.ApplyDomain(strconv.Itoa(i)+".not-exists2.com"), IsFalse)
}
}

View File

@ -52,24 +52,11 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
conds := NewConditionChan() conds := NewConditionChan()
if len(rr.Domain) > 0 { if len(rr.Domain) > 0 {
anyCond := NewAnyCondition() matcher := NewCachableDomainMatcher()
for _, domain := range rr.Domain { for _, domain := range rr.Domain {
switch domain.Type { matcher.Add(domain)
case Domain_Plain:
anyCond.Add(NewPlainDomainMatcher(domain.Value))
case Domain_Regex:
matcher, err := NewRegexpDomainMatcher(domain.Value)
if err != nil {
return nil, err
} }
anyCond.Add(matcher) conds.Add(matcher)
case Domain_Domain:
anyCond.Add(NewSubDomainMatcher(domain.Value))
default:
panic("Unknown domain type.")
}
}
conds.Add(anyCond)
} }
if len(rr.Cidr) > 0 { if len(rr.Cidr) > 0 {

View File

@ -54,23 +54,27 @@ const (
Config_UseIp Config_DomainStrategy = 1 Config_UseIp Config_DomainStrategy = 1
// Resolve to IP if the domain doesn't match any rules. // Resolve to IP if the domain doesn't match any rules.
Config_IpIfNonMatch Config_DomainStrategy = 2 Config_IpIfNonMatch Config_DomainStrategy = 2
// Resolve to IP if any rule requires IP matching.
Config_IpOnDemand Config_DomainStrategy = 3
) )
var Config_DomainStrategy_name = map[int32]string{ var Config_DomainStrategy_name = map[int32]string{
0: "AsIs", 0: "AsIs",
1: "UseIp", 1: "UseIp",
2: "IpIfNonMatch", 2: "IpIfNonMatch",
3: "IpOnDemand",
} }
var Config_DomainStrategy_value = map[string]int32{ var Config_DomainStrategy_value = map[string]int32{
"AsIs": 0, "AsIs": 0,
"UseIp": 1, "UseIp": 1,
"IpIfNonMatch": 2, "IpIfNonMatch": 2,
"IpOnDemand": 3,
} }
func (x Config_DomainStrategy) String() string { func (x Config_DomainStrategy) String() string {
return proto.EnumName(Config_DomainStrategy_name, int32(x)) return proto.EnumName(Config_DomainStrategy_name, int32(x))
} }
func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{3, 0} } func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{7, 0} }
// Domain for routing decision. // Domain for routing decision.
type Domain struct { type Domain struct {
@ -126,6 +130,86 @@ func (m *CIDR) GetPrefix() uint32 {
return 0 return 0
} }
type GeoIP struct {
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode" json:"country_code,omitempty"`
Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr" json:"cidr,omitempty"`
}
func (m *GeoIP) Reset() { *m = GeoIP{} }
func (m *GeoIP) String() string { return proto.CompactTextString(m) }
func (*GeoIP) ProtoMessage() {}
func (*GeoIP) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *GeoIP) GetCountryCode() string {
if m != nil {
return m.CountryCode
}
return ""
}
func (m *GeoIP) GetCidr() []*CIDR {
if m != nil {
return m.Cidr
}
return nil
}
type GeoIPList struct {
Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry" json:"entry,omitempty"`
}
func (m *GeoIPList) Reset() { *m = GeoIPList{} }
func (m *GeoIPList) String() string { return proto.CompactTextString(m) }
func (*GeoIPList) ProtoMessage() {}
func (*GeoIPList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *GeoIPList) GetEntry() []*GeoIP {
if m != nil {
return m.Entry
}
return nil
}
type GeoSite struct {
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode" json:"country_code,omitempty"`
Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"`
}
func (m *GeoSite) Reset() { *m = GeoSite{} }
func (m *GeoSite) String() string { return proto.CompactTextString(m) }
func (*GeoSite) ProtoMessage() {}
func (*GeoSite) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *GeoSite) GetCountryCode() string {
if m != nil {
return m.CountryCode
}
return ""
}
func (m *GeoSite) GetDomain() []*Domain {
if m != nil {
return m.Domain
}
return nil
}
type GeoSiteList struct {
Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry" json:"entry,omitempty"`
}
func (m *GeoSiteList) Reset() { *m = GeoSiteList{} }
func (m *GeoSiteList) String() string { return proto.CompactTextString(m) }
func (*GeoSiteList) ProtoMessage() {}
func (*GeoSiteList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *GeoSiteList) GetEntry() []*GeoSite {
if m != nil {
return m.Entry
}
return nil
}
type RoutingRule struct { type RoutingRule struct {
Tag string `protobuf:"bytes,1,opt,name=tag" json:"tag,omitempty"` Tag string `protobuf:"bytes,1,opt,name=tag" json:"tag,omitempty"`
Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"` Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"`
@ -140,7 +224,7 @@ type RoutingRule struct {
func (m *RoutingRule) Reset() { *m = RoutingRule{} } func (m *RoutingRule) Reset() { *m = RoutingRule{} }
func (m *RoutingRule) String() string { return proto.CompactTextString(m) } func (m *RoutingRule) String() string { return proto.CompactTextString(m) }
func (*RoutingRule) ProtoMessage() {} func (*RoutingRule) ProtoMessage() {}
func (*RoutingRule) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} } func (*RoutingRule) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *RoutingRule) GetTag() string { func (m *RoutingRule) GetTag() string {
if m != nil { if m != nil {
@ -206,7 +290,7 @@ type Config struct {
func (m *Config) Reset() { *m = Config{} } func (m *Config) Reset() { *m = Config{} }
func (m *Config) String() string { return proto.CompactTextString(m) } func (m *Config) String() string { return proto.CompactTextString(m) }
func (*Config) ProtoMessage() {} func (*Config) ProtoMessage() {}
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} } func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *Config) GetDomainStrategy() Config_DomainStrategy { func (m *Config) GetDomainStrategy() Config_DomainStrategy {
if m != nil { if m != nil {
@ -225,6 +309,10 @@ func (m *Config) GetRule() []*RoutingRule {
func init() { func init() {
proto.RegisterType((*Domain)(nil), "v2ray.core.app.router.Domain") proto.RegisterType((*Domain)(nil), "v2ray.core.app.router.Domain")
proto.RegisterType((*CIDR)(nil), "v2ray.core.app.router.CIDR") proto.RegisterType((*CIDR)(nil), "v2ray.core.app.router.CIDR")
proto.RegisterType((*GeoIP)(nil), "v2ray.core.app.router.GeoIP")
proto.RegisterType((*GeoIPList)(nil), "v2ray.core.app.router.GeoIPList")
proto.RegisterType((*GeoSite)(nil), "v2ray.core.app.router.GeoSite")
proto.RegisterType((*GeoSiteList)(nil), "v2ray.core.app.router.GeoSiteList")
proto.RegisterType((*RoutingRule)(nil), "v2ray.core.app.router.RoutingRule") proto.RegisterType((*RoutingRule)(nil), "v2ray.core.app.router.RoutingRule")
proto.RegisterType((*Config)(nil), "v2ray.core.app.router.Config") proto.RegisterType((*Config)(nil), "v2ray.core.app.router.Config")
proto.RegisterEnum("v2ray.core.app.router.Domain_Type", Domain_Type_name, Domain_Type_value) proto.RegisterEnum("v2ray.core.app.router.Domain_Type", Domain_Type_name, Domain_Type_value)
@ -234,39 +322,45 @@ func init() {
func init() { proto.RegisterFile("v2ray.com/core/app/router/config.proto", fileDescriptor0) } func init() { proto.RegisterFile("v2ray.com/core/app/router/config.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{ var fileDescriptor0 = []byte{
// 538 bytes of a gzipped FileDescriptorProto // 640 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x93, 0xc1, 0x6e, 0xd4, 0x30, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x94, 0xcd, 0x6e, 0xd4, 0x3a,
0x10, 0x86, 0x49, 0x76, 0x1b, 0xba, 0x93, 0xb2, 0x44, 0x16, 0x45, 0xa1, 0xa8, 0x22, 0x8a, 0x10, 0x14, 0xc7, 0x6f, 0xe6, 0xab, 0x9d, 0x93, 0xb9, 0x73, 0x23, 0xeb, 0x16, 0x0d, 0x85, 0xc2, 0x10,
0xe4, 0x80, 0x12, 0x69, 0x11, 0x70, 0x01, 0xa1, 0xb2, 0xed, 0x61, 0x25, 0xa8, 0x2a, 0xd3, 0x72, 0x21, 0x98, 0x05, 0x4a, 0xa4, 0xe1, 0x63, 0x05, 0xaa, 0xca, 0xb4, 0xaa, 0x22, 0x41, 0x19, 0xb9,
0xe0, 0x12, 0xb9, 0x59, 0x37, 0x58, 0x24, 0xb6, 0xe5, 0x38, 0xa5, 0x7b, 0xe3, 0x05, 0x78, 0x11, 0x2d, 0x0b, 0x58, 0x44, 0x6e, 0xe2, 0x86, 0x88, 0x89, 0x6d, 0x39, 0x4e, 0xe9, 0xec, 0x78, 0x01,
0x9e, 0x86, 0x47, 0x42, 0xb6, 0x53, 0xd1, 0xa2, 0x2e, 0xdc, 0x66, 0x9c, 0xef, 0x9f, 0x19, 0x8f, 0x5e, 0x84, 0xa7, 0xe2, 0x51, 0x90, 0xed, 0x0c, 0xb4, 0xa8, 0x81, 0x8a, 0x9d, 0xed, 0xfc, 0xfe,
0xff, 0xc0, 0x93, 0xf3, 0x99, 0x22, 0xab, 0xbc, 0x12, 0x6d, 0x51, 0x09, 0x45, 0x0b, 0x22, 0x65, 0xe7, 0xfc, 0x73, 0x7c, 0x8e, 0xe1, 0xc1, 0xd9, 0x54, 0x92, 0x65, 0x90, 0xf0, 0x22, 0x4c, 0xb8,
0xa1, 0x44, 0xaf, 0xa9, 0x2a, 0x2a, 0xc1, 0xcf, 0x58, 0x9d, 0x4b, 0x25, 0xb4, 0x40, 0xdb, 0x97, 0xa4, 0x21, 0x11, 0x22, 0x94, 0xbc, 0x52, 0x54, 0x86, 0x09, 0x67, 0xa7, 0x79, 0x16, 0x08, 0xc9,
0x9c, 0xa2, 0x39, 0x91, 0x32, 0x77, 0xcc, 0xce, 0xe3, 0xbf, 0xe4, 0x95, 0x68, 0x5b, 0xc1, 0x0b, 0x15, 0x47, 0x1b, 0x2b, 0x4e, 0xd2, 0x80, 0x08, 0x11, 0x58, 0x66, 0xf3, 0xfe, 0x2f, 0xf2, 0x84,
0x4e, 0x75, 0x21, 0x85, 0xd2, 0x4e, 0xbc, 0xf3, 0x74, 0x3d, 0xc5, 0xa9, 0xfe, 0x26, 0xd4, 0x57, 0x17, 0x05, 0x67, 0x21, 0xa3, 0x2a, 0x14, 0x5c, 0x2a, 0x2b, 0xde, 0x7c, 0xd8, 0x4c, 0x31, 0xaa,
0x07, 0xa6, 0xdf, 0x3d, 0x08, 0xf6, 0x45, 0x4b, 0x18, 0x47, 0x2f, 0x61, 0xac, 0x57, 0x92, 0xc6, 0x3e, 0x71, 0xf9, 0xd1, 0x82, 0xfe, 0x67, 0x07, 0x7a, 0xbb, 0xbc, 0x20, 0x39, 0x43, 0xcf, 0xa0,
0x5e, 0xe2, 0x65, 0xd3, 0x59, 0x9a, 0xdf, 0xd8, 0x3f, 0x77, 0x70, 0x7e, 0xbc, 0x92, 0x14, 0x5b, 0xa3, 0x96, 0x82, 0x8e, 0x9c, 0xb1, 0x33, 0x19, 0x4e, 0xfd, 0xe0, 0xca, 0xfc, 0x81, 0x85, 0x83,
0x1e, 0xdd, 0x83, 0x8d, 0x73, 0xd2, 0xf4, 0x34, 0xf6, 0x13, 0x2f, 0x9b, 0x60, 0x97, 0xa4, 0x19, 0xa3, 0xa5, 0xa0, 0xd8, 0xf0, 0xe8, 0x7f, 0xe8, 0x9e, 0x91, 0x45, 0x45, 0x47, 0xad, 0xb1, 0x33,
0x8c, 0x0d, 0x83, 0x26, 0xb0, 0x71, 0xd4, 0x10, 0xc6, 0xa3, 0x5b, 0x26, 0xc4, 0xb4, 0xa6, 0x17, 0xe9, 0x63, 0xbb, 0xf1, 0x27, 0xd0, 0xd1, 0x0c, 0xea, 0x43, 0x77, 0xbe, 0x20, 0x39, 0xf3, 0xfe,
0x91, 0x87, 0xe0, 0xb2, 0x6b, 0xe4, 0xa7, 0x39, 0x8c, 0xe7, 0x8b, 0x7d, 0x8c, 0xa6, 0xe0, 0x33, 0xd1, 0x4b, 0x4c, 0x33, 0x7a, 0xee, 0x39, 0x08, 0x56, 0x59, 0xbd, 0x96, 0x1f, 0x40, 0x67, 0x16,
0x69, 0xbb, 0x6f, 0x61, 0x9f, 0x49, 0x74, 0x1f, 0x02, 0xa9, 0xe8, 0x19, 0xbb, 0xb0, 0x85, 0xef, 0xed, 0x62, 0x34, 0x84, 0x56, 0x2e, 0x4c, 0xf6, 0x01, 0x6e, 0xe5, 0x02, 0xdd, 0x80, 0x9e, 0x90,
0xe0, 0x21, 0x4b, 0x7f, 0x8c, 0x20, 0xc4, 0xa2, 0xd7, 0x8c, 0xd7, 0xb8, 0x6f, 0x28, 0x8a, 0x60, 0xf4, 0x34, 0x3f, 0x37, 0x81, 0xff, 0xc5, 0xf5, 0xce, 0x7f, 0x0f, 0xdd, 0x7d, 0xca, 0xa3, 0x39,
0xa4, 0x49, 0x6d, 0x85, 0x13, 0x6c, 0x42, 0xf4, 0x02, 0x82, 0xa5, 0xad, 0x1e, 0xfb, 0xc9, 0x28, 0xba, 0x07, 0x83, 0x84, 0x57, 0x4c, 0xc9, 0x65, 0x9c, 0xf0, 0xd4, 0x1a, 0xef, 0x63, 0xb7, 0x3e,
0x0b, 0x67, 0xbb, 0xff, 0xbc, 0x0b, 0x1e, 0x60, 0x54, 0xc0, 0xb8, 0x62, 0x4b, 0x15, 0x8f, 0xac, 0x9b, 0xf1, 0x94, 0xa2, 0x10, 0x3a, 0x49, 0x9e, 0xca, 0x51, 0x6b, 0xdc, 0x9e, 0xb8, 0xd3, 0x5b,
0xe8, 0xe1, 0x1a, 0x91, 0x99, 0x15, 0x5b, 0x10, 0xbd, 0x05, 0x30, 0x3b, 0x2f, 0x15, 0xe1, 0x35, 0x0d, 0xff, 0xa4, 0xd3, 0x63, 0x03, 0xfa, 0xdb, 0xd0, 0x37, 0xc1, 0x5f, 0xe5, 0xa5, 0x42, 0x53,
0x8d, 0xc7, 0x89, 0x97, 0x85, 0xb3, 0xe4, 0xaa, 0xcc, 0xad, 0x3d, 0xe7, 0x54, 0xe7, 0x47, 0x42, 0xe8, 0x52, 0x1d, 0x6a, 0xe4, 0x18, 0xf9, 0xed, 0x06, 0xb9, 0x11, 0x60, 0x8b, 0xfa, 0x09, 0xac,
0x69, 0x6c, 0x38, 0x3c, 0x91, 0x97, 0x21, 0x3a, 0x80, 0xad, 0xe1, 0x39, 0xca, 0x86, 0x75, 0x3a, 0xed, 0x53, 0x7e, 0x98, 0x2b, 0x7a, 0x1d, 0x7f, 0x4f, 0xa1, 0x97, 0x9a, 0x3a, 0xd4, 0x0e, 0xb7,
0xde, 0xb0, 0x25, 0xd2, 0x35, 0x25, 0x0e, 0x1d, 0xfa, 0x9e, 0x75, 0x1a, 0x87, 0xfc, 0x4f, 0x82, 0x7e, 0x5b, 0x75, 0x5c, 0xc3, 0xfe, 0x0c, 0xdc, 0x3a, 0x89, 0xf1, 0xf9, 0xe4, 0xb2, 0xcf, 0x3b,
0x5e, 0x43, 0xd8, 0x89, 0x5e, 0x55, 0xb4, 0xb4, 0xf3, 0x07, 0xff, 0x9f, 0x1f, 0x1c, 0x3f, 0x37, 0xcd, 0x3e, 0xb5, 0x64, 0xe5, 0xf4, 0x4b, 0x1b, 0x5c, 0xcc, 0x2b, 0x95, 0xb3, 0x0c, 0x57, 0x0b,
0xb7, 0xd8, 0x05, 0xe8, 0x3b, 0xaa, 0x4a, 0xda, 0x12, 0xd6, 0xc4, 0xb7, 0x93, 0x51, 0x36, 0xc1, 0x8a, 0x3c, 0x68, 0x2b, 0x92, 0xd5, 0x2e, 0xf5, 0xf2, 0x2f, 0xdd, 0xfd, 0x28, 0x7a, 0xfb, 0x9a,
0x13, 0x73, 0x72, 0x60, 0x0e, 0xd0, 0x23, 0x08, 0x19, 0x3f, 0x15, 0x3d, 0x5f, 0x96, 0x66, 0xcd, 0x45, 0x47, 0xdb, 0x00, 0xba, 0x77, 0x63, 0x49, 0x58, 0x46, 0x47, 0x9d, 0xb1, 0x33, 0x71, 0xa7,
0x9b, 0xf6, 0x3b, 0x0c, 0x47, 0xc7, 0xa4, 0x4e, 0x7f, 0x79, 0x10, 0xcc, 0xad, 0x73, 0xd1, 0x09, 0xe3, 0x8b, 0x32, 0xdb, 0xbe, 0x01, 0xa3, 0x2a, 0x98, 0x73, 0xa9, 0xb0, 0xe6, 0x70, 0x5f, 0xac,
0xdc, 0x75, 0xbb, 0x2c, 0x3b, 0xad, 0x88, 0xa6, 0xf5, 0x6a, 0x70, 0xd3, 0xb3, 0x75, 0xc3, 0x38, 0x96, 0x68, 0x0f, 0x06, 0x75, 0x5b, 0xc7, 0x8b, 0xbc, 0x54, 0xa3, 0xae, 0x09, 0xe1, 0x37, 0x84,
0xc7, 0xbb, 0x87, 0xf8, 0x38, 0x68, 0xf0, 0x74, 0x79, 0x2d, 0x37, 0xce, 0x54, 0x7d, 0x43, 0x87, 0x38, 0xb0, 0xa8, 0x2e, 0x1d, 0x76, 0xd9, 0xcf, 0x0d, 0x7a, 0x0e, 0x6e, 0xc9, 0x2b, 0x99, 0xd0,
0xd7, 0x5c, 0xe7, 0xcc, 0x2b, 0x9e, 0xc0, 0x96, 0x4f, 0x5f, 0xc1, 0xf4, 0x7a, 0x65, 0xb4, 0x09, 0xd8, 0xf8, 0xef, 0xfd, 0xd9, 0x3f, 0x58, 0x7e, 0xa6, 0xff, 0x62, 0x0b, 0xa0, 0x2a, 0xa9, 0x8c,
0xe3, 0xbd, 0x6e, 0xd1, 0x39, 0x33, 0x9e, 0x74, 0x74, 0x21, 0x23, 0x0f, 0x45, 0xb0, 0xb5, 0x90, 0x69, 0x41, 0xf2, 0xc5, 0x68, 0x6d, 0xdc, 0x9e, 0xf4, 0x71, 0x5f, 0x9f, 0xec, 0xe9, 0x03, 0x74,
0x8b, 0xb3, 0x43, 0xc1, 0x3f, 0x10, 0x5d, 0x7d, 0x89, 0xfc, 0x77, 0x6f, 0xe0, 0x41, 0x25, 0xda, 0x17, 0xdc, 0x9c, 0x9d, 0xf0, 0x8a, 0xa5, 0xb1, 0x2e, 0xf3, 0xba, 0xf9, 0x0e, 0xf5, 0xd1, 0x11,
0x9b, 0xfb, 0x1c, 0x79, 0x9f, 0x03, 0x17, 0xfd, 0xf4, 0xb7, 0x3f, 0xcd, 0x30, 0x59, 0xe5, 0x73, 0xc9, 0xfc, 0x6f, 0x0e, 0xf4, 0x66, 0xe6, 0x05, 0x40, 0xc7, 0xf0, 0x9f, 0xad, 0x65, 0x5c, 0x2a,
0x43, 0xec, 0x49, 0x69, 0x47, 0xa0, 0xea, 0x34, 0xb0, 0xff, 0xd6, 0xf3, 0xdf, 0x01, 0x00, 0x00, 0x49, 0x14, 0xcd, 0x96, 0xf5, 0x54, 0x3e, 0x6a, 0x32, 0x63, 0x5f, 0x0e, 0x7b, 0x11, 0x87, 0xb5,
0xff, 0xff, 0xa7, 0x6a, 0x97, 0x93, 0xeb, 0x03, 0x00, 0x00, 0x06, 0x0f, 0xd3, 0x4b, 0x7b, 0x3d, 0xe1, 0xb2, 0x5a, 0xd0, 0xfa, 0x36, 0x9b, 0x26, 0xfc, 0x42,
0x4f, 0x60, 0xc3, 0xfb, 0xfb, 0x30, 0xbc, 0x1c, 0x19, 0xad, 0x43, 0x67, 0xa7, 0x8c, 0x4a, 0x3b,
0xd4, 0xc7, 0x25, 0x8d, 0x84, 0xe7, 0x20, 0x0f, 0x06, 0x91, 0x88, 0x4e, 0x0f, 0x38, 0x7b, 0x4d,
0x54, 0xf2, 0xc1, 0x6b, 0xa1, 0x21, 0x40, 0x24, 0xde, 0xb0, 0x5d, 0x5a, 0x10, 0x96, 0x7a, 0xed,
0x97, 0x2f, 0xe0, 0x66, 0xc2, 0x8b, 0xab, 0xf3, 0xce, 0x9d, 0x77, 0x3d, 0xbb, 0xfa, 0xda, 0xda,
0x78, 0x3b, 0xc5, 0x64, 0x19, 0xcc, 0x34, 0xb1, 0x23, 0x84, 0xb1, 0x44, 0xe5, 0x49, 0xcf, 0xbc,
0x59, 0x8f, 0xbf, 0x07, 0x00, 0x00, 0xff, 0xff, 0x53, 0x7c, 0xa8, 0x94, 0x43, 0x05, 0x00, 0x00,
} }

View File

@ -37,6 +37,24 @@ message CIDR {
uint32 prefix = 2; uint32 prefix = 2;
} }
message GeoIP {
string country_code = 1;
repeated CIDR cidr = 2;
}
message GeoIPList {
repeated GeoIP entry = 1;
}
message GeoSite {
string country_code = 1;
repeated Domain domain = 2;
}
message GeoSiteList{
repeated GeoSite entry = 1;
}
message RoutingRule { message RoutingRule {
string tag = 1; string tag = 1;
repeated Domain domain = 2; repeated Domain domain = 2;
@ -58,6 +76,9 @@ message Config {
// Resolve to IP if the domain doesn't match any rules. // Resolve to IP if the domain doesn't match any rules.
IpIfNonMatch = 2; IpIfNonMatch = 2;
// Resolve to IP if any rule requires IP matching.
IpOnDemand = 3;
} }
DomainStrategy domain_strategy = 1; DomainStrategy domain_strategy = 1;
repeated RoutingRule rule = 2; repeated RoutingRule rule = 2;

View File

@ -33,7 +33,7 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
rules: make([]Rule, len(config.Rule)), rules: make([]Rule, len(config.Rule)),
} }
space.OnInitialize(func() error { space.On(app.SpaceInitializing, func(interface{}) error {
for idx, rule := range config.Rule { for idx, rule := range config.Rule {
r.rules[idx].Tag = rule.Tag r.rules[idx].Tag = rule.Tag
cond, err := rule.BuildCondition() cond, err := rule.BuildCondition()
@ -52,19 +52,42 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
return r, nil return r, nil
} }
func (r *Router) resolveIP(dest net.Destination) []net.Address { type ipResolver struct {
ips := r.dnsServer.Get(dest.Address.Domain()) ip []net.Address
domain string
resolved bool
dnsServer dns.Server
}
func (r *ipResolver) Resolve() []net.Address {
if r.resolved {
return r.ip
}
log.Trace(newError("looking for IP for domain: ", r.domain))
r.resolved = true
ips := r.dnsServer.Get(r.domain)
if len(ips) == 0 { if len(ips) == 0 {
return nil return nil
} }
dests := make([]net.Address, len(ips)) r.ip = make([]net.Address, len(ips))
for idx, ip := range ips { for i, ip := range ips {
dests[idx] = net.IPAddress(ip) r.ip[i] = net.IPAddress(ip)
} }
return dests return r.ip
} }
func (r *Router) TakeDetour(ctx context.Context) (string, error) { func (r *Router) TakeDetour(ctx context.Context) (string, error) {
resolver := &ipResolver{
dnsServer: r.dnsServer,
}
if r.domainStrategy == Config_IpOnDemand {
if dest, ok := proxy.TargetFromContext(ctx); ok && dest.Address.Family().IsDomain() {
resolver.domain = dest.Address.Domain()
ctx = proxy.ContextWithResolveIPs(ctx, resolver)
}
}
for _, rule := range r.rules { for _, rule := range r.rules {
if rule.Apply(ctx) { if rule.Apply(ctx) {
return rule.Tag, nil return rule.Tag, nil
@ -77,10 +100,10 @@ func (r *Router) TakeDetour(ctx context.Context) (string, error) {
} }
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() { if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
log.Trace(newError("looking up IP for ", dest)) resolver.domain = dest.Address.Domain()
ipDests := r.resolveIP(dest) ips := resolver.Resolve()
if ipDests != nil { if len(ips) > 0 {
ctx = proxy.ContextWithResolveIPs(ctx, ipDests) ctx = proxy.ContextWithResolveIPs(ctx, resolver)
for _, rule := range r.rules { for _, rule := range r.rules {
if rule.Apply(ctx) { if rule.Apply(ctx) {
return rule.Tag, nil return rule.Tag, nil

View File

@ -14,11 +14,11 @@ import (
. "v2ray.com/core/app/router" . "v2ray.com/core/app/router"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/proxy" "v2ray.com/core/proxy"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestSimpleRouter(t *testing.T) { func TestSimpleRouter(t *testing.T) {
assert := assert.On(t) assert := With(t)
config := &Config{ config := &Config{
Rule: []*RoutingRule{ Rule: []*RoutingRule{
@ -33,16 +33,16 @@ func TestSimpleRouter(t *testing.T) {
space := app.NewSpace() space := app.NewSpace()
ctx := app.ContextWithSpace(context.Background(), space) ctx := app.ContextWithSpace(context.Background(), space)
assert.Error(app.AddApplicationToSpace(ctx, new(dns.Config))).IsNil() assert(app.AddApplicationToSpace(ctx, new(dns.Config)), IsNil)
assert.Error(app.AddApplicationToSpace(ctx, new(dispatcher.Config))).IsNil() assert(app.AddApplicationToSpace(ctx, new(dispatcher.Config)), IsNil)
assert.Error(app.AddApplicationToSpace(ctx, new(proxyman.OutboundConfig))).IsNil() assert(app.AddApplicationToSpace(ctx, new(proxyman.OutboundConfig)), IsNil)
assert.Error(app.AddApplicationToSpace(ctx, config)).IsNil() assert(app.AddApplicationToSpace(ctx, config), IsNil)
assert.Error(space.Initialize()).IsNil() assert(space.Initialize(), IsNil)
r := FromSpace(space) r := FromSpace(space)
ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(net.DomainAddress("v2ray.com"), 80)) ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(net.DomainAddress("v2ray.com"), 80))
tag, err := r.TakeDetour(ctx) tag, err := r.TakeDetour(ctx)
assert.Error(err).IsNil() assert(err, IsNil)
assert.String(tag).Equals("test") assert(tag, Equals, "test")
} }

View File

@ -5,6 +5,7 @@ import (
"reflect" "reflect"
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/event"
) )
type Application interface { type Application interface {
@ -13,8 +14,6 @@ type Application interface {
Close() Close()
} }
type InitializationCallback func() error
func CreateAppFromConfig(ctx context.Context, config interface{}) (Application, error) { func CreateAppFromConfig(ctx context.Context, config interface{}) (Application, error) {
application, err := common.CreateObject(ctx, config) application, err := common.CreateObject(ctx, config)
if err != nil { if err != nil {
@ -29,47 +28,48 @@ func CreateAppFromConfig(ctx context.Context, config interface{}) (Application,
} }
// A Space contains all apps that may be available in a V2Ray runtime. // A Space contains all apps that may be available in a V2Ray runtime.
// Caller must check the availability of an app by calling HasXXX before getting its instance.
type Space interface { type Space interface {
event.Registry
GetApplication(appInterface interface{}) Application GetApplication(appInterface interface{}) Application
AddApplication(application Application) error AddApplication(application Application) error
Initialize() error Initialize() error
OnInitialize(InitializationCallback)
Start() error Start() error
Close() Close()
} }
const (
// SpaceInitializing is an event to be fired when Space is being initialized.
SpaceInitializing event.Event = iota
)
type spaceImpl struct { type spaceImpl struct {
initialized bool event.Listener
cache map[reflect.Type]Application cache map[reflect.Type]Application
appInit []InitializationCallback initialized bool
} }
// NewSpace creates a new Space.
func NewSpace() Space { func NewSpace() Space {
return &spaceImpl{ return &spaceImpl{
cache: make(map[reflect.Type]Application), cache: make(map[reflect.Type]Application),
appInit: make([]InitializationCallback, 0, 32),
} }
} }
func (s *spaceImpl) OnInitialize(f InitializationCallback) { func (s *spaceImpl) On(e event.Event, h event.Handler) {
if s.initialized { if e == SpaceInitializing && s.initialized {
f() _ = h(nil) // Ignore error
} else { return
s.appInit = append(s.appInit, f)
} }
s.Listener.On(e, h)
} }
func (s *spaceImpl) Initialize() error { func (s *spaceImpl) Initialize() error {
for _, f := range s.appInit { if s.initialized {
if err := f(); err != nil {
return err
}
}
s.appInit = nil
s.initialized = true
return nil return nil
} }
s.initialized = true
return s.Fire(SpaceInitializing, nil)
}
func (s *spaceImpl) GetApplication(appInterface interface{}) Application { func (s *spaceImpl) GetApplication(appInterface interface{}) Application {
if s == nil { if s == nil {

View File

@ -1,52 +0,0 @@
package vpndialer
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Config struct {
Address string `protobuf:"bytes,1,opt,name=address" json:"address,omitempty"`
}
func (m *Config) Reset() { *m = Config{} }
func (m *Config) String() string { return proto.CompactTextString(m) }
func (*Config) ProtoMessage() {}
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Config) GetAddress() string {
if m != nil {
return m.Address
}
return ""
}
func init() {
proto.RegisterType((*Config)(nil), "v2ray.core.app.vpndialer.Config")
}
func init() { proto.RegisterFile("v2ray.com/core/app/vpndialer/config.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 150 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2c, 0x33, 0x2a, 0x4a,
0xac, 0xd4, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xce, 0x2f, 0x4a, 0xd5, 0x4f, 0x2c, 0x28, 0xd0, 0x2f,
0x2b, 0xc8, 0x4b, 0xc9, 0x4c, 0xcc, 0x49, 0x2d, 0xd2, 0x4f, 0xce, 0xcf, 0x4b, 0xcb, 0x4c, 0xd7,
0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x92, 0x80, 0x29, 0x2d, 0x4a, 0xd5, 0x4b, 0x2c, 0x28, 0xd0,
0x83, 0x2b, 0x53, 0x52, 0xe2, 0x62, 0x73, 0x06, 0xab, 0x14, 0x92, 0xe0, 0x62, 0x4f, 0x4c, 0x49,
0x29, 0x4a, 0x2d, 0x2e, 0x96, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x82, 0x71, 0x9d, 0xdc, 0xb8,
0x64, 0x92, 0xf3, 0x73, 0xf5, 0x70, 0x99, 0x11, 0xc0, 0x18, 0xc5, 0x09, 0xe7, 0xac, 0x62, 0x92,
0x08, 0x33, 0x0a, 0x4a, 0xac, 0xd4, 0x73, 0x06, 0xa9, 0x73, 0x2c, 0x28, 0xd0, 0x0b, 0x2b, 0xc8,
0x73, 0x01, 0x4b, 0x25, 0xb1, 0x81, 0x1d, 0x63, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x97, 0xfc,
0x09, 0x70, 0xb9, 0x00, 0x00, 0x00,
}

View File

@ -1,11 +0,0 @@
syntax = "proto3";
package v2ray.core.app.vpndialer;
option csharp_namespace = "V2Ray.Core.App.VpnDialer";
option go_package = "vpndialer";
option java_package = "com.v2ray.core.app.vpndialer";
option java_multiple_files = true;
message Config {
string address = 1;
}

View File

@ -1,7 +0,0 @@
package unix
import "v2ray.com/core/common/errors"
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).Path("App", "VPNDialer", "Unix")
}

View File

@ -1,217 +0,0 @@
package unix
import (
"context"
"os"
"sync"
"golang.org/x/sys/unix"
"v2ray.com/core/app/vpndialer"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg unix -path App,VPNDialer,Unix
type status int
const (
statusNew status = iota
statusOK
statusFail
)
type fdStatus struct {
status status
fd int
callback chan<- error
}
type protector struct {
sync.Mutex
address string
conn *net.UnixConn
status chan fdStatus
}
func readFrom(conn *net.UnixConn, schan chan<- fdStatus) {
var payload [6]byte
for {
_, err := conn.Read(payload[:])
if err != nil {
break
}
fd := serial.BytesToInt(payload[1:5])
s := status(payload[5])
schan <- fdStatus{
fd: fd,
status: s,
}
}
}
func (m *protector) dial() (*net.UnixConn, error) {
m.Lock()
defer m.Unlock()
if m.conn != nil {
return m.conn, nil
}
conn, err := net.DialUnix("unix", nil, &net.UnixAddr{
Name: m.address,
Net: "unix",
})
if err != nil {
return nil, err
}
m.conn = conn
m.status = make(chan fdStatus, 32)
go readFrom(conn, m.status)
go m.monitor(conn)
return conn, nil
}
func (m *protector) close() {
m.Lock()
defer m.Unlock()
if m.conn == nil {
return
}
m.conn.Close()
m.conn = nil
}
func (m *protector) monitor(c *net.UnixConn) {
pendingFd := make(map[int]chan<- error, 32)
for s := range m.status {
switch s.status {
case statusNew:
pendingFd[s.fd] = s.callback
case statusOK:
if c, f := pendingFd[s.fd]; f {
close(c)
delete(pendingFd, s.fd)
}
case statusFail:
if c, f := pendingFd[s.fd]; f {
c <- newError("failed to protect fd")
close(c)
delete(pendingFd, s.fd)
}
}
}
}
func (m *protector) protect(fd int) error {
conn, err := m.dial()
if err != nil {
return err
}
var payload [6]byte
serial.IntToBytes(fd, payload[1:1])
payload[5] = byte(statusNew)
if _, err := conn.Write(payload[:]); err != nil {
return err
}
wait := make(chan error)
m.status <- fdStatus{
status: statusNew,
fd: fd,
callback: wait,
}
return <-wait
}
type App struct {
protector *protector
dialer *Dialer
}
func NewApp(ctx context.Context, config *vpndialer.Config) (*App, error) {
a := &App{
dialer: &Dialer{},
protector: &protector{
address: config.Address,
},
}
a.dialer.protect = a.protector.protect
return a, nil
}
func (*App) Interface() interface{} {
return (*App)(nil)
}
func (a *App) Start() error {
internet.UseAlternativeSystemDialer(a.dialer)
return nil
}
func (a *App) Close() {
internet.UseAlternativeSystemDialer(nil)
}
type Dialer struct {
protect func(fd int) error
}
func socket(dest net.Destination) (int, error) {
switch dest.Network {
case net.Network_TCP:
return unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, unix.IPPROTO_TCP)
case net.Network_UDP:
return unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
default:
return 0, newError("unknown network ", dest.Network)
}
}
func getIP(addr net.Address) (net.IP, error) {
if addr.Family().Either(net.AddressFamilyIPv4, net.AddressFamilyIPv6) {
return addr.IP(), nil
}
ips, err := net.LookupIP(addr.Domain())
if err != nil {
return nil, err
}
return ips[0], nil
}
func (d *Dialer) Dial(ctx context.Context, source net.Address, dest net.Destination) (net.Conn, error) {
fd, err := socket(dest)
if err != nil {
return nil, err
}
if err := d.protect(fd); err != nil {
return nil, err
}
ip, err := getIP(dest.Address)
if err != nil {
return nil, err
}
addr := &unix.SockaddrInet6{
Port: int(dest.Port),
ZoneId: 0,
}
copy(addr.Addr[:], ip.To16())
if err := unix.Connect(fd, addr); err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "Socket")
return net.FileConn(file)
}
func init() {
common.Must(common.RegisterConfig((*vpndialer.Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return NewApp(ctx, config.(*vpndialer.Config))
}))
}

View File

@ -1 +0,0 @@
package vpndialer

21
common/bitmask/byte.go Normal file
View File

@ -0,0 +1,21 @@
package bitmask
// Byte is a bitmask in byte.
type Byte byte
// Has returns true if this bitmask contains another bitmask.
func (b Byte) Has(bb Byte) bool {
return (b & bb) != 0
}
func (b *Byte) Set(bb Byte) {
*b |= bb
}
func (b *Byte) Clear(bb Byte) {
*b &= ^bb
}
func (b *Byte) Toggle(bb Byte) {
*b ^= bb
}

View File

@ -0,0 +1,27 @@
package bitmask_test
import (
"testing"
. "v2ray.com/core/common/bitmask"
. "v2ray.com/ext/assert"
)
func TestBitmaskByte(t *testing.T) {
assert := With(t)
b := Byte(0)
b.Set(Byte(1))
assert(b.Has(1), IsTrue)
b.Set(Byte(2))
assert(b.Has(2), IsTrue)
assert(b.Has(1), IsTrue)
b.Clear(Byte(1))
assert(b.Has(2), IsTrue)
assert(b.Has(1), IsFalse)
b.Toggle(Byte(2))
assert(b.Has(2), IsFalse)
}

View File

@ -1,10 +1,7 @@
package buf package buf
import ( import (
"runtime"
"sync" "sync"
"v2ray.com/core/common/platform"
) )
// Pool provides functionality to generate and recycle buffers on demand. // Pool provides functionality to generate and recycle buffers on demand.
@ -45,79 +42,11 @@ func (p *SyncPool) Free(buffer *Buffer) {
} }
} }
// BufferPool is a Pool that utilizes an internal cache.
type BufferPool struct {
chain chan []byte
sub Pool
}
// NewBufferPool creates a new BufferPool with given buffer size, and internal cache size.
func NewBufferPool(bufferSize, poolSize uint32) *BufferPool {
pool := &BufferPool{
chain: make(chan []byte, poolSize),
sub: NewSyncPool(bufferSize),
}
for i := uint32(0); i < poolSize; i++ {
pool.chain <- make([]byte, bufferSize)
}
return pool
}
// Allocate implements Pool.Allocate().
func (p *BufferPool) Allocate() *Buffer {
select {
case b := <-p.chain:
return &Buffer{
v: b,
pool: p,
}
default:
return p.sub.Allocate()
}
}
// Free implements Pool.Free().
func (p *BufferPool) Free(buffer *Buffer) {
if buffer.v == nil {
return
}
select {
case p.chain <- buffer.v:
default:
p.sub.Free(buffer)
}
}
const ( const (
// Size of a regular buffer. // Size of a regular buffer.
Size = 2 * 1024 Size = 2 * 1024
poolSizeEnvKey = "v2ray.buffer.size"
) )
var ( var (
mediumPool Pool mediumPool Pool = NewSyncPool(Size)
) )
func getDefaultPoolSize() int {
switch runtime.GOARCH {
case "amd64", "386":
return 20
default:
return 5
}
}
func init() {
f := platform.EnvFlag{
Name: poolSizeEnvKey,
AltName: platform.NormalizeEnvName(poolSizeEnvKey),
}
size := f.GetValueAsInt(getDefaultPoolSize())
if size > 0 {
totalByteSize := uint32(size) * 1024 * 1024
mediumPool = NewBufferPool(Size, totalByteSize/Size)
} else {
mediumPool = NewSyncPool(Size)
}
}

View File

@ -6,64 +6,64 @@ import (
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestBufferClear(t *testing.T) { func TestBufferClear(t *testing.T) {
assert := assert.On(t) assert := With(t)
buffer := New() buffer := New()
defer buffer.Release() defer buffer.Release()
payload := "Bytes" payload := "Bytes"
buffer.Append([]byte(payload)) buffer.Append([]byte(payload))
assert.Int(buffer.Len()).Equals(len(payload)) assert(buffer.Len(), Equals, len(payload))
buffer.Clear() buffer.Clear()
assert.Int(buffer.Len()).Equals(0) assert(buffer.Len(), Equals, 0)
} }
func TestBufferIsEmpty(t *testing.T) { func TestBufferIsEmpty(t *testing.T) {
assert := assert.On(t) assert := With(t)
buffer := New() buffer := New()
defer buffer.Release() defer buffer.Release()
assert.Bool(buffer.IsEmpty()).IsTrue() assert(buffer.IsEmpty(), IsTrue)
} }
func TestBufferString(t *testing.T) { func TestBufferString(t *testing.T) {
assert := assert.On(t) assert := With(t)
buffer := New() buffer := New()
defer buffer.Release() defer buffer.Release()
assert.Error(buffer.AppendSupplier(serial.WriteString("Test String"))).IsNil() assert(buffer.AppendSupplier(serial.WriteString("Test String")), IsNil)
assert.String(buffer.String()).Equals("Test String") assert(buffer.String(), Equals, "Test String")
} }
func TestBufferWrite(t *testing.T) { func TestBufferWrite(t *testing.T) {
assert := assert.On(t) assert := With(t)
buffer := NewLocal(8) buffer := NewLocal(8)
nBytes, err := buffer.Write([]byte("abcd")) nBytes, err := buffer.Write([]byte("abcd"))
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(nBytes).Equals(4) assert(nBytes, Equals, 4)
nBytes, err = buffer.Write([]byte("abcde")) nBytes, err = buffer.Write([]byte("abcde"))
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(nBytes).Equals(4) assert(nBytes, Equals, 4)
assert.String(buffer.String()).Equals("abcdabcd") assert(buffer.String(), Equals, "abcdabcd")
} }
func TestSyncPool(t *testing.T) { func TestSyncPool(t *testing.T) {
assert := assert.On(t) assert := With(t)
p := NewSyncPool(32) p := NewSyncPool(32)
b := p.Allocate() b := p.Allocate()
assert.Int(b.Len()).Equals(0) assert(b.Len(), Equals, 0)
assert.Error(b.AppendSupplier(ReadFrom(rand.Reader))).IsNil() assert(b.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
assert.Int(b.Len()).Equals(32) assert(b.Len(), Equals, 32)
b.Release() b.Release()
} }

View File

@ -1,53 +0,0 @@
package buf
import (
"io"
)
// BufferedReader is a reader with internal cache.
type BufferedReader struct {
reader io.Reader
buffer *Buffer
buffered bool
}
// NewBufferedReader creates a new BufferedReader based on an io.Reader.
func NewBufferedReader(rawReader io.Reader) *BufferedReader {
return &BufferedReader{
reader: rawReader,
buffer: NewLocal(1024),
buffered: true,
}
}
// IsBuffered returns true if the internal cache is effective.
func (v *BufferedReader) IsBuffered() bool {
return v.buffered
}
// SetBuffered is to enable or disable internal cache. If cache is disabled,
// Read() calls will be delegated to the underlying io.Reader directly.
func (v *BufferedReader) SetBuffered(cached bool) {
v.buffered = cached
}
// Read implements io.Reader.Read().
func (v *BufferedReader) Read(b []byte) (int, error) {
if !v.buffered || v.buffer == nil {
if !v.buffer.IsEmpty() {
return v.buffer.Read(b)
}
return v.reader.Read(b)
}
if v.buffer.IsEmpty() {
if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
return 0, err
}
}
if v.buffer.IsEmpty() {
return 0, nil
}
return v.buffer.Read(b)
}

View File

@ -1,36 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
)
func TestBufferedReader(t *testing.T) {
assert := assert.On(t)
content := New()
assert.Error(content.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
len := content.Len()
reader := NewBufferedReader(content)
assert.Bool(reader.IsBuffered()).IsTrue()
payload := make([]byte, 16)
nBytes, err := reader.Read(payload)
assert.Int(nBytes).Equals(16)
assert.Error(err).IsNil()
len2 := content.Len()
assert.Int(len - len2).GreaterThan(16)
nBytes, err = reader.Read(payload)
assert.Int(nBytes).Equals(16)
assert.Error(err).IsNil()
assert.Int(content.Len()).Equals(len2)
}

View File

@ -1,72 +0,0 @@
package buf
import "io"
// BufferedWriter is an io.Writer with internal buffer. It writes to underlying writer when buffer is full or on demand.
// This type is not thread safe.
type BufferedWriter struct {
writer io.Writer
buffer *Buffer
buffered bool
}
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer io.Writer) *BufferedWriter {
return NewBufferedWriterSize(writer, 1024)
}
func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: NewLocal(int(size)),
buffered: true,
}
}
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered || w.buffer == nil {
return w.writer.Write(b)
}
bytesWritten := 0
for bytesWritten < len(b) {
nBytes, err := w.buffer.Write(b[bytesWritten:])
if err != nil {
return bytesWritten, err
}
bytesWritten += nBytes
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return bytesWritten, err
}
}
}
return bytesWritten, nil
}
// Flush writes all buffered content into underlying writer, if any.
func (w *BufferedWriter) Flush() error {
defer w.buffer.Clear()
for !w.buffer.IsEmpty() {
nBytes, err := w.writer.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.SliceFrom(nBytes)
}
return nil
}
// IsBuffered returns true if this BufferedWriter holds a buffer.
func (w *BufferedWriter) IsBuffered() bool {
return w.buffered
}
// SetBuffered controls whether the BufferedWriter holds a buffer for writing. If not buffered, any write() calls into underlying writer directly.
func (w *BufferedWriter) SetBuffered(cached bool) error {
w.buffered = cached
if !cached && !w.buffer.IsEmpty() {
return w.Flush()
}
return nil
}

View File

@ -1,53 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
)
func TestBufferedWriter(t *testing.T) {
assert := assert.On(t)
content := New()
writer := NewBufferedWriter(content)
assert.Bool(writer.IsBuffered()).IsTrue()
payload := make([]byte, 16)
nBytes, err := writer.Write(payload)
assert.Int(nBytes).Equals(16)
assert.Error(err).IsNil()
assert.Bool(content.IsEmpty()).IsTrue()
assert.Error(writer.SetBuffered(false)).IsNil()
assert.Int(content.Len()).Equals(16)
}
func TestBufferedWriterLargePayload(t *testing.T) {
assert := assert.On(t)
content := NewLocal(128 * 1024)
writer := NewBufferedWriter(content)
assert.Bool(writer.IsBuffered()).IsTrue()
payload := make([]byte, 64*1024)
rand.Read(payload)
nBytes, err := writer.Write(payload[:512])
assert.Int(nBytes).Equals(512)
assert.Error(err).IsNil()
assert.Bool(content.IsEmpty()).IsTrue()
nBytes, err = writer.Write(payload[512:])
assert.Error(err).IsNil()
assert.Error(writer.Flush()).IsNil()
assert.Int(nBytes).Equals(64*1024 - 512)
assert.Bytes(content.Bytes()).Equals(payload)
}

View File

@ -17,7 +17,7 @@ type copyHandler struct {
} }
func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) { func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
mb, err := reader.Read() mb, err := reader.ReadMultiBuffer()
if err != nil { if err != nil {
for _, handler := range h.onReadError { for _, handler := range h.onReadError {
err = handler(err) err = handler(err)
@ -27,7 +27,7 @@ func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
} }
func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error { func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
err := writer.Write(mb) err := writer.WriteMultiBuffer(mb)
if err != nil { if err != nil {
for _, handler := range h.onWriteError { for _, handler := range h.onWriteError {
err = handler(err) err = handler(err)
@ -36,8 +36,14 @@ func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
return err return err
} }
type SizeCounter struct {
Size int64
}
// CopyOption is an option for copying data.
type CopyOption func(*copyHandler) type CopyOption func(*copyHandler)
// IgnoreReaderError is a CopyOption that ignores errors from reader. Copy will continue in such case.
func IgnoreReaderError() CopyOption { func IgnoreReaderError() CopyOption {
return func(handler *copyHandler) { return func(handler *copyHandler) {
handler.onReadError = append(handler.onReadError, func(err error) error { handler.onReadError = append(handler.onReadError, func(err error) error {
@ -46,6 +52,7 @@ func IgnoreReaderError() CopyOption {
} }
} }
// IgnoreWriterError is a CopyOption that ignores errors from writer. Copy will continue in such case.
func IgnoreWriterError() CopyOption { func IgnoreWriterError() CopyOption {
return func(handler *copyHandler) { return func(handler *copyHandler) {
handler.onWriteError = append(handler.onWriteError, func(err error) error { handler.onWriteError = append(handler.onWriteError, func(err error) error {
@ -54,6 +61,7 @@ func IgnoreWriterError() CopyOption {
} }
} }
// UpdateActivity is a CopyOption to update activity on each data copy operation.
func UpdateActivity(timer signal.ActivityUpdater) CopyOption { func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
return func(handler *copyHandler) { return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(MultiBuffer) { handler.onData = append(handler.onData, func(MultiBuffer) {
@ -62,31 +70,34 @@ func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
} }
} }
// CountSize is a CopyOption that sums the total size of data copied into the given SizeCounter.
func CountSize(sc *SizeCounter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
sc.Size += int64(b.Len())
})
}
}
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error { func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
for { for {
buffer, err := handler.readFrom(reader) buffer, err := handler.readFrom(reader)
if err != nil { if !buffer.IsEmpty() {
return err
}
if buffer.IsEmpty() {
buffer.Release()
continue
}
for _, handler := range handler.onData { for _, handler := range handler.onData {
handler(buffer) handler(buffer)
} }
if err := handler.writeTo(writer, buffer); err != nil { if werr := handler.writeTo(writer, buffer); werr != nil {
buffer.Release() buffer.Release()
return werr
}
} else if err != nil {
return err return err
} }
} }
} }
// Copy dumps all payload from reader to writer or stops when an error occurs. // Copy dumps all payload from reader to writer or stops when an error occurs. It returns nil when EOF.
// ActivityTimer gets updated as soon as there is a payload.
func Copy(reader Reader, writer Writer, options ...CopyOption) error { func Copy(reader Reader, writer Writer, options ...CopyOption) error {
handler := new(copyHandler) handler := new(copyHandler)
for _, option := range options { for _, option := range options {

View File

@ -5,22 +5,24 @@ import (
"time" "time"
) )
// Reader extends io.Reader with alloc.Buffer. // Reader extends io.Reader with MultiBuffer.
type Reader interface { type Reader interface {
// Read reads content from underlying reader, and put it into an alloc.Buffer. // ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
Read() (MultiBuffer, error) ReadMultiBuffer() (MultiBuffer, error)
} }
// ErrReadTimeout is an error that happens with IO timeout.
var ErrReadTimeout = newError("IO timeout") var ErrReadTimeout = newError("IO timeout")
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
type TimeoutReader interface { type TimeoutReader interface {
ReadTimeout(time.Duration) (MultiBuffer, error) ReadTimeout(time.Duration) (MultiBuffer, error)
} }
// Writer extends io.Writer with alloc.Buffer. // Writer extends io.Writer with MultiBuffer.
type Writer interface { type Writer interface {
// Write writes an alloc.Buffer into underlying writer. // WriteMultiBuffer writes a MultiBuffer into underlying writer.
Write(MultiBuffer) error WriteMultiBuffer(MultiBuffer) error
} }
// ReadFrom creates a Supplier to read from a given io.Reader. // ReadFrom creates a Supplier to read from a given io.Reader.
@ -47,57 +49,21 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
// NewReader creates a new Reader. // NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader. // The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader { func NewReader(reader io.Reader) Reader {
if mr, ok := reader.(MultiBufferReader); ok { if mr, ok := reader.(Reader); ok {
return &readerAdpater{ return mr
MultiBufferReader: mr,
}
} }
return &BytesToBufferReader{ return NewBytesToBufferReader(reader)
reader: reader,
buffer: make([]byte, 32*1024),
}
}
func NewMergingReader(reader io.Reader) Reader {
return NewMergingReaderSize(reader, 32*1024)
}
func NewMergingReaderSize(reader io.Reader, size uint32) Reader {
return &BytesToBufferReader{
reader: reader,
buffer: make([]byte, size),
}
}
// ToBytesReader converts a Reaaer to io.Reader.
func ToBytesReader(stream Reader) io.Reader {
return &bufferToBytesReader{
stream: stream,
}
} }
// NewWriter creates a new Writer. // NewWriter creates a new Writer.
func NewWriter(writer io.Writer) Writer { func NewWriter(writer io.Writer) Writer {
if mw, ok := writer.(MultiBufferWriter); ok { if mw, ok := writer.(Writer); ok {
return &writerAdapter{ return mw
writer: mw,
}
} }
return &BufferToBytesWriter{ return &BufferToBytesWriter{
writer: writer, Writer: writer,
}
}
func NewMergingWriter(writer io.Writer) Writer {
return NewMergingWriterSize(writer, 4096)
}
func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
return &mergingWriter{
writer: writer,
buffer: make([]byte, size),
} }
} }
@ -106,10 +72,3 @@ func NewSequentialWriter(writer io.Writer) Writer {
writer: writer, writer: writer,
} }
} }
// ToBytesWriter converts a Writer to io.Writer
func ToBytesWriter(writer Writer) io.Writer {
return &bytesToBufferWriter{
writer: writer,
}
}

View File

@ -1,21 +1,53 @@
package buf package buf
import "net" import (
"io"
"net"
type MultiBufferWriter interface { "v2ray.com/core/common"
WriteMultiBuffer(MultiBuffer) error "v2ray.com/core/common/errors"
)
// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF.
func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
mb := NewMultiBufferCap(128)
for {
b := New()
err := b.Reset(ReadFrom(reader))
if b.IsEmpty() {
b.Release()
} else {
mb.Append(b)
}
if err != nil {
if errors.Cause(err) == io.EOF {
return mb, nil
}
mb.Release()
return nil, err
}
}
} }
type MultiBufferReader interface { // ReadAllToBytes reads all content from the reader into a byte array, until EOF.
ReadMultiBuffer() (MultiBuffer, error) func ReadAllToBytes(reader io.Reader) ([]byte, error) {
mb, err := ReadAllToMultiBuffer(reader)
if err != nil {
return nil, err
}
b := make([]byte, mb.Len())
common.Must2(mb.Read(b))
mb.Release()
return b, nil
} }
// MultiBuffer is a list of Buffers. The order of Buffer matters. // MultiBuffer is a list of Buffers. The order of Buffer matters.
type MultiBuffer []*Buffer type MultiBuffer []*Buffer
// NewMultiBuffer creates a new MultiBuffer instance. // NewMultiBufferCap creates a new MultiBuffer instance.
func NewMultiBuffer() MultiBuffer { func NewMultiBufferCap(capacity int) MultiBuffer {
return MultiBuffer(make([]*Buffer, 0, 128)) return MultiBuffer(make([]*Buffer, 0, capacity))
} }
// NewMultiBufferValue wraps a list of Buffers into MultiBuffer. // NewMultiBufferValue wraps a list of Buffers into MultiBuffer.
@ -23,14 +55,17 @@ func NewMultiBufferValue(b ...*Buffer) MultiBuffer {
return MultiBuffer(b) return MultiBuffer(b)
} }
// Append appends buffer to the end of this MultiBuffer
func (mb *MultiBuffer) Append(buf *Buffer) { func (mb *MultiBuffer) Append(buf *Buffer) {
*mb = append(*mb, buf) *mb = append(*mb, buf)
} }
// AppendMulti appends a MultiBuffer to the end of this one.
func (mb *MultiBuffer) AppendMulti(buf MultiBuffer) { func (mb *MultiBuffer) AppendMulti(buf MultiBuffer) {
*mb = append(*mb, buf...) *mb = append(*mb, buf...)
} }
// Copy copied the begining part of the MultiBuffer into the given byte array.
func (mb MultiBuffer) Copy(b []byte) int { func (mb MultiBuffer) Copy(b []byte) int {
total := 0 total := 0
for _, bb := range mb { for _, bb := range mb {
@ -43,6 +78,7 @@ func (mb MultiBuffer) Copy(b []byte) int {
return total return total
} }
// Read implements io.Reader.
func (mb *MultiBuffer) Read(b []byte) (int, error) { func (mb *MultiBuffer) Read(b []byte) (int, error) {
endIndex := len(*mb) endIndex := len(*mb)
totalBytes := 0 totalBytes := 0
@ -52,6 +88,7 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
b = b[nBytes:] b = b[nBytes:]
if bb.IsEmpty() { if bb.IsEmpty() {
bb.Release() bb.Release()
(*mb)[i] = nil
} else { } else {
endIndex = i endIndex = i
break break
@ -61,6 +98,7 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
return totalBytes, nil return totalBytes, nil
} }
// Write implements io.Writer.
func (mb *MultiBuffer) Write(b []byte) { func (mb *MultiBuffer) Write(b []byte) {
n := len(*mb) n := len(*mb)
if n > 0 && !(*mb)[n-1].IsFull() { if n > 0 && !(*mb)[n-1].IsFull() {
@ -96,11 +134,12 @@ func (mb MultiBuffer) IsEmpty() bool {
} }
// Release releases all Buffers in the MultiBuffer. // Release releases all Buffers in the MultiBuffer.
func (mb MultiBuffer) Release() { func (mb *MultiBuffer) Release() {
for i, b := range mb { for i, b := range *mb {
b.Release() b.Release()
mb[i] = nil (*mb)[i] = nil
} }
*mb = nil
} }
// ToNetBuffers converts this MultiBuffer to net.Buffers. The return net.Buffers points to the same content of the MultiBuffer. // ToNetBuffers converts this MultiBuffer to net.Buffers. The return net.Buffers points to the same content of the MultiBuffer.
@ -112,8 +151,9 @@ func (mb MultiBuffer) ToNetBuffers() net.Buffers {
return bs return bs
} }
// SliceBySize splits the begining of this MultiBuffer into another one, for at most size bytes.
func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer { func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
slice := NewMultiBuffer() slice := NewMultiBufferCap(10)
sliceSize := 0 sliceSize := 0
endIndex := len(*mb) endIndex := len(*mb)
for i, b := range *mb { for i, b := range *mb {
@ -123,16 +163,24 @@ func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
} }
sliceSize += b.Len() sliceSize += b.Len()
slice.Append(b) slice.Append(b)
(*mb)[i] = nil
} }
*mb = (*mb)[endIndex:] *mb = (*mb)[endIndex:]
if endIndex == 0 && len(*mb) > 0 {
b := New()
common.Must(b.Reset(ReadFullFrom((*mb)[0], size)))
return NewMultiBufferValue(b)
}
return slice return slice
} }
// SplitFirst splits out the first Buffer in this MultiBuffer.
func (mb *MultiBuffer) SplitFirst() *Buffer { func (mb *MultiBuffer) SplitFirst() *Buffer {
if len(*mb) == 0 { if len(*mb) == 0 {
return nil return nil
} }
b := (*mb)[0] b := (*mb)[0]
(*mb)[0] = nil
*mb = (*mb)[1:] *mb = (*mb)[1:]
return b return b
} }

View File

@ -4,11 +4,11 @@ import (
"testing" "testing"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestMultiBufferRead(t *testing.T) { func TestMultiBufferRead(t *testing.T) {
assert := assert.On(t) assert := With(t)
b1 := New() b1 := New()
b1.AppendBytes('a', 'b') b1.AppendBytes('a', 'b')
@ -19,17 +19,17 @@ func TestMultiBufferRead(t *testing.T) {
bs := make([]byte, 32) bs := make([]byte, 32)
nBytes, err := mb.Read(bs) nBytes, err := mb.Read(bs)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(nBytes).Equals(4) assert(nBytes, Equals, 4)
assert.Bytes(bs[:nBytes]).Equals([]byte("abcd")) assert(bs[:nBytes], Equals, []byte("abcd"))
} }
func TestMultiBufferAppend(t *testing.T) { func TestMultiBufferAppend(t *testing.T) {
assert := assert.On(t) assert := With(t)
var mb MultiBuffer var mb MultiBuffer
b := New() b := New()
b.AppendBytes('a', 'b') b.AppendBytes('a', 'b')
mb.Append(b) mb.Append(b)
assert.Int(mb.Len()).Equals(2) assert(mb.Len(), Equals, 2)
} }

View File

@ -6,38 +6,86 @@ import (
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
) )
var (
_ Reader = (*BytesToBufferReader)(nil)
_ io.Reader = (*BytesToBufferReader)(nil)
)
// BytesToBufferReader is a Reader that adjusts its reading speed automatically. // BytesToBufferReader is a Reader that adjusts its reading speed automatically.
type BytesToBufferReader struct { type BytesToBufferReader struct {
reader io.Reader io.Reader
buffer []byte buffer []byte
} }
// Read implements Reader.Read(). func NewBytesToBufferReader(reader io.Reader) Reader {
func (r *BytesToBufferReader) Read() (MultiBuffer, error) { return &BytesToBufferReader{
nBytes, err := r.reader.Read(r.buffer) Reader: reader,
if err != nil { }
}
func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
b := New()
err := b.Reset(ReadFrom(r.Reader))
if b.IsFull() {
r.buffer = make([]byte, 32*1024)
}
if !b.IsEmpty() {
return NewMultiBufferValue(b), nil
}
b.Release()
return nil, err return nil, err
} }
mb := NewMultiBuffer() // ReadMultiBuffer implements Reader.
func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.buffer == nil {
return r.readSmall()
}
nBytes, err := r.Reader.Read(r.buffer)
if nBytes > 0 {
mb := NewMultiBufferCap(nBytes/Size + 1)
mb.Write(r.buffer[:nBytes]) mb.Write(r.buffer[:nBytes])
return mb, nil return mb, nil
} }
return nil, err
type readerAdpater struct {
MultiBufferReader
} }
func (r *readerAdpater) Read() (MultiBuffer, error) { var (
return r.ReadMultiBuffer() _ Reader = (*BufferedReader)(nil)
} _ io.Reader = (*BufferedReader)(nil)
_ io.ByteReader = (*BufferedReader)(nil)
_ io.WriterTo = (*BufferedReader)(nil)
)
type bufferToBytesReader struct { type BufferedReader struct {
stream Reader stream Reader
leftOver MultiBuffer leftOver MultiBuffer
buffered bool
} }
func (r *bufferToBytesReader) Read(b []byte) (int, error) { func NewBufferedReader(reader Reader) *BufferedReader {
return &BufferedReader{
stream: reader,
buffered: true,
}
}
func (r *BufferedReader) SetBuffered(f bool) {
r.buffered = f
}
func (r *BufferedReader) IsBuffered() bool {
return r.buffered
}
func (r *BufferedReader) ReadByte() (byte, error) {
var b [1]byte
_, err := r.Read(b[:])
return b[0], err
}
func (r *BufferedReader) Read(b []byte) (int, error) {
if r.leftOver != nil { if r.leftOver != nil {
nBytes, _ := r.leftOver.Read(b) nBytes, _ := r.leftOver.Read(b)
if r.leftOver.IsEmpty() { if r.leftOver.IsEmpty() {
@ -47,51 +95,75 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
return nBytes, nil return nBytes, nil
} }
mb, err := r.stream.Read() if !r.buffered {
if err != nil { if reader, ok := r.stream.(io.Reader); ok {
return 0, err return reader.Read(b)
}
} }
mb, err := r.stream.ReadMultiBuffer()
if mb != nil {
nBytes, _ := mb.Read(b) nBytes, _ := mb.Read(b)
if !mb.IsEmpty() { if !mb.IsEmpty() {
r.leftOver = mb r.leftOver = mb
} }
return nBytes, nil return nBytes, err
}
return 0, err
} }
func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) { func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.leftOver != nil { if r.leftOver != nil {
mb := r.leftOver mb := r.leftOver
r.leftOver = nil r.leftOver = nil
return mb, nil return mb, nil
} }
return r.stream.Read() return r.stream.ReadMultiBuffer()
} }
func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) { // ReadAtMost returns a MultiBuffer with at most size.
func (r *BufferedReader) ReadAtMost(size int) (MultiBuffer, error) {
if r.leftOver == nil {
mb, err := r.stream.ReadMultiBuffer()
if mb.IsEmpty() && err != nil {
return nil, err
}
r.leftOver = mb
}
mb := r.leftOver.SliceBySize(size)
if r.leftOver.IsEmpty() {
r.leftOver = nil
}
return mb, nil
}
func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
mbWriter := NewWriter(writer) mbWriter := NewWriter(writer)
totalBytes := int64(0) totalBytes := int64(0)
if r.leftOver != nil { if r.leftOver != nil {
if err := mbWriter.Write(r.leftOver); err != nil { totalBytes += int64(r.leftOver.Len())
if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil {
return 0, err return 0, err
} }
totalBytes += int64(r.leftOver.Len())
} }
for { for {
mb, err := r.stream.Read() mb, err := r.stream.ReadMultiBuffer()
if err != nil { if mb != nil {
totalBytes += int64(mb.Len())
if werr := mbWriter.WriteMultiBuffer(mb); werr != nil {
return totalBytes, err return totalBytes, err
} }
totalBytes += int64(mb.Len()) }
if err := mbWriter.Write(mb); err != nil { if err != nil {
return totalBytes, err return totalBytes, err
} }
} }
} }
func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) { func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
nBytes, err := r.writeToInternal(writer) nBytes, err := r.writeToInternal(writer)
if errors.Cause(err) == io.EOF { if errors.Cause(err) == io.EOF {
return nBytes, nil return nBytes, nil

View File

@ -7,64 +7,66 @@ import (
"testing" "testing"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray" "v2ray.com/core/transport/ray"
. "v2ray.com/ext/assert"
) )
func TestAdaptiveReader(t *testing.T) { func TestAdaptiveReader(t *testing.T) {
assert := assert.On(t) assert := With(t)
rawContent := make([]byte, 1024*1024) reader := NewReader(bytes.NewReader(make([]byte, 1024*1024)))
buffer := bytes.NewBuffer(rawContent) b, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(b.Len(), Equals, 2*1024)
reader := NewReader(buffer) b, err = reader.ReadMultiBuffer()
b, err := reader.Read() assert(err, IsNil)
assert.Error(err).IsNil() assert(b.Len(), Equals, 32*1024)
assert.Int(b.Len()).Equals(32 * 1024)
} }
func TestBytesReaderWriteTo(t *testing.T) { func TestBytesReaderWriteTo(t *testing.T) {
assert := assert.On(t) assert := With(t)
stream := ray.NewStream(context.Background()) stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream) reader := NewBufferedReader(stream)
b1 := New() b1 := New()
b1.AppendBytes('a', 'b', 'c') b1.AppendBytes('a', 'b', 'c')
b2 := New() b2 := New()
b2.AppendBytes('e', 'f', 'g') b2.AppendBytes('e', 'f', 'g')
assert.Error(stream.Write(NewMultiBufferValue(b1, b2))).IsNil() assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close() stream.Close()
stream2 := ray.NewStream(context.Background()) stream2 := ray.NewStream(context.Background())
writer := ToBytesWriter(stream2) writer := NewBufferedWriter(stream2)
writer.SetBuffered(false)
nBytes, err := io.Copy(writer, reader) nBytes, err := io.Copy(writer, reader)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int64(nBytes).Equals(6) assert(nBytes, Equals, int64(6))
mb, err := stream2.Read() mb, err := stream2.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(len(mb)).Equals(2) assert(len(mb), Equals, 2)
assert.String(mb[0].String()).Equals("abc") assert(mb[0].String(), Equals, "abc")
assert.String(mb[1].String()).Equals("efg") assert(mb[1].String(), Equals, "efg")
} }
func TestBytesReaderMultiBuffer(t *testing.T) { func TestBytesReaderMultiBuffer(t *testing.T) {
assert := assert.On(t) assert := With(t)
stream := ray.NewStream(context.Background()) stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream) reader := NewBufferedReader(stream)
b1 := New() b1 := New()
b1.AppendBytes('a', 'b', 'c') b1.AppendBytes('a', 'b', 'c')
b2 := New() b2 := New()
b2.AppendBytes('e', 'f', 'g') b2.AppendBytes('e', 'f', 'g')
assert.Error(stream.Write(NewMultiBufferValue(b1, b2))).IsNil() assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close() stream.Close()
mbReader := NewReader(reader) mbReader := NewReader(reader)
mb, err := mbReader.Read() mb, err := mbReader.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(len(mb)).Equals(2) assert(len(mb), Equals, 2)
assert.String(mb[0].String()).Equals("abc") assert(mb[0].String(), Equals, "abc")
assert.String(mb[1].String()).Equals("efg") assert(mb[1].String(), Equals, "efg")
} }

View File

@ -1,52 +1,164 @@
package buf package buf
import "io" import (
"io"
"v2ray.com/core/common/errors"
)
var (
_ io.ReaderFrom = (*BufferToBytesWriter)(nil)
_ io.Writer = (*BufferToBytesWriter)(nil)
_ Writer = (*BufferToBytesWriter)(nil)
)
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer. // BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
type BufferToBytesWriter struct { type BufferToBytesWriter struct {
writer io.Writer io.Writer
} }
// Write implements Writer.Write(). Write() takes ownership of the given buffer. func NewBufferToBytesWriter(writer io.Writer) *BufferToBytesWriter {
func (w *BufferToBytesWriter) Write(mb MultiBuffer) error { return &BufferToBytesWriter{
Writer: writer,
}
}
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release() defer mb.Release()
bs := mb.ToNetBuffers() bs := mb.ToNetBuffers()
_, err := bs.WriteTo(w.writer) _, err := bs.WriteTo(w)
return err return err
} }
type writerAdapter struct { // ReadFrom implements io.ReaderFrom.
writer MultiBufferWriter func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
} }
// Write implements buf.MultiBufferWriter. var (
func (w *writerAdapter) Write(mb MultiBuffer) error { _ io.ReaderFrom = (*BufferedWriter)(nil)
return w.writer.WriteMultiBuffer(mb) _ io.Writer = (*BufferedWriter)(nil)
_ Writer = (*BufferedWriter)(nil)
_ io.ByteWriter = (*BufferedWriter)(nil)
)
// BufferedWriter is a Writer with internal buffer.
type BufferedWriter struct {
writer Writer
buffer *Buffer
buffered bool
} }
type mergingWriter struct { // NewBufferedWriter creates a new BufferedWriter.
writer io.Writer func NewBufferedWriter(writer Writer) *BufferedWriter {
buffer []byte return &BufferedWriter{
writer: writer,
buffer: New(),
buffered: true,
}
} }
func (w *mergingWriter) Write(mb MultiBuffer) error { func (w *BufferedWriter) WriteByte(c byte) error {
defer mb.Release() _, err := w.Write([]byte{c})
for !mb.IsEmpty() {
nBytes, _ := mb.Read(w.buffer)
if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil {
return err return err
} }
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered {
if writer, ok := w.writer.(io.Writer); ok {
return writer.Write(b)
}
}
totalBytes := 0
for len(b) > 0 {
if w.buffer == nil {
w.buffer = New()
}
nBytes, err := w.buffer.Write(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if !w.buffered || w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return totalBytes, err
}
}
b = b[nBytes:]
}
return totalBytes, nil
}
// WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if !w.buffered {
return w.writer.WriteMultiBuffer(b)
}
defer b.Release()
for !b.IsEmpty() {
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
return err
}
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return err
}
}
}
return nil
}
// Flush flushes buffered content into underlying writer.
func (w *BufferedWriter) Flush() error {
if !w.buffer.IsEmpty() {
if err := w.writer.WriteMultiBuffer(NewMultiBufferValue(w.buffer)); err != nil {
return err
}
if w.buffered {
w.buffer = New()
} else {
w.buffer = nil
}
} }
return nil return nil
} }
func (w *BufferedWriter) SetBuffered(f bool) error {
w.buffered = f
if !f {
return w.Flush()
}
return nil
}
// ReadFrom implements io.ReaderFrom.
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
if err := w.SetBuffered(false); err != nil {
return 0, err
}
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
type seqWriter struct { type seqWriter struct {
writer io.Writer writer io.Writer
} }
func (w *seqWriter) Write(mb MultiBuffer) error { func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release() defer mb.Release()
for _, b := range mb { for _, b := range mb {
@ -61,54 +173,38 @@ func (w *seqWriter) Write(mb MultiBuffer) error {
return nil return nil
} }
var (
_ MultiBufferWriter = (*bytesToBufferWriter)(nil)
)
type bytesToBufferWriter struct {
writer Writer
}
// Write implements io.Writer.
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
mb := NewMultiBuffer()
mb.Write(payload)
if err := w.writer.Write(mb); err != nil {
return 0, err
}
return len(payload), nil
}
func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) error {
return w.writer.Write(mb)
}
func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) {
mbReader := NewReader(reader)
totalBytes := int64(0)
eof := false
for !eof {
mb, err := mbReader.Read()
if err == io.EOF {
eof = true
} else if err != nil {
return totalBytes, err
}
totalBytes += int64(mb.Len())
if err := w.writer.Write(mb); err != nil {
return totalBytes, err
}
}
return totalBytes, nil
}
type noOpWriter struct{} type noOpWriter struct{}
func (noOpWriter) Write(b MultiBuffer) error { func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
b.Release() b.Release()
return nil return nil
} }
func (noOpWriter) Write(b []byte) (int, error) {
return len(b), nil
}
func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
b := New()
defer b.Release()
totalBytes := int64(0)
for {
err := b.Reset(ReadFrom(reader))
totalBytes += int64(b.Len())
if err != nil {
if errors.Cause(err) == io.EOF {
return totalBytes, nil
}
return totalBytes, err
}
}
}
var ( var (
// Discard is a Writer that swallows all contents written in.
Discard Writer = noOpWriter{} Discard Writer = noOpWriter{}
// DiscardBytes is an io.Writer that swallows all contents written in.
DiscardBytes io.Writer = noOpWriter{}
) )

View File

@ -9,37 +9,67 @@ import (
"context" "context"
"io" "io"
"v2ray.com/core/common"
. "v2ray.com/core/common/buf" . "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray" "v2ray.com/core/transport/ray"
. "v2ray.com/ext/assert"
) )
func TestWriter(t *testing.T) { func TestWriter(t *testing.T) {
assert := assert.On(t) assert := With(t)
lb := New() lb := New()
assert.Error(lb.AppendSupplier(ReadFrom(rand.Reader))).IsNil() assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
expectedBytes := append([]byte(nil), lb.Bytes()...) expectedBytes := append([]byte(nil), lb.Bytes()...)
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024)) writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
writer := NewWriter(NewBufferedWriter(writeBuffer)) writer := NewBufferedWriter(NewWriter(writeBuffer))
err := writer.Write(NewMultiBufferValue(lb)) writer.SetBuffered(false)
assert.Error(err).IsNil() err := writer.WriteMultiBuffer(NewMultiBufferValue(lb))
assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes()) assert(err, IsNil)
assert(writer.Flush(), IsNil)
assert(expectedBytes, Equals, writeBuffer.Bytes())
} }
func TestBytesWriterReadFrom(t *testing.T) { func TestBytesWriterReadFrom(t *testing.T) {
assert := assert.On(t) assert := With(t)
cache := ray.NewStream(context.Background()) cache := ray.NewStream(context.Background())
reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192)) const size = 50000
_, err := reader.WriteTo(ToBytesWriter(cache)) reader := bufio.NewReader(io.LimitReader(rand.Reader, size))
assert.Error(err).IsNil() writer := NewBufferedWriter(cache)
writer.SetBuffered(false)
nBytes, err := reader.WriteTo(writer)
assert(nBytes, Equals, int64(size))
assert(err, IsNil)
mb, err := cache.Read() mb, err := cache.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(mb.Len()).Equals(8192) assert(mb.Len(), Equals, size)
assert.Int(len(mb)).Equals(4) }
func TestDiscardBytes(t *testing.T) {
assert := With(t)
b := New()
common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
nBytes, err := io.Copy(DiscardBytes, b)
assert(nBytes, Equals, int64(Size))
assert(err, IsNil)
}
func TestDiscardBytesMultiBuffer(t *testing.T) {
assert := With(t)
const size = 10240*1024 + 1
buffer := bytes.NewBuffer(make([]byte, 0, size))
common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
r := NewReader(buffer)
nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r))
assert(nBytes, Equals, int64(size))
assert(err, IsNil)
} }

View File

@ -11,6 +11,7 @@ func Must(err error) {
} }
} }
// Must2 panics if the second parameter is not nil.
func Must2(v interface{}, err error) { func Must2(v interface{}, err error) {
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -29,6 +29,26 @@ func (v StaticBytesGenerator) Next() []byte {
return v.Content return v.Content
} }
type IncreasingAEADNonceGenerator struct {
nonce []byte
}
func NewIncreasingAEADNonceGenerator() *IncreasingAEADNonceGenerator {
return &IncreasingAEADNonceGenerator{
nonce: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
}
}
func (g *IncreasingAEADNonceGenerator) Next() []byte {
for i := range g.nonce {
g.nonce[i]++
if g.nonce[i] != 0 {
break
}
}
return g.nonce
}
type Authenticator interface { type Authenticator interface {
NonceSize() int NonceSize() int
Overhead() int Overhead() int
@ -48,7 +68,10 @@ func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
return nil, newError("invalid AEAD nonce size: ", len(iv)) return nil, newError("invalid AEAD nonce size: ", len(iv))
} }
additionalData := v.AdditionalDataGenerator.Next() var additionalData []byte
if v.AdditionalDataGenerator != nil {
additionalData = v.AdditionalDataGenerator.Next()
}
return v.AEAD.Open(dst, iv, cipherText, additionalData) return v.AEAD.Open(dst, iv, cipherText, additionalData)
} }
@ -58,7 +81,10 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
return nil, newError("invalid AEAD nonce size: ", len(iv)) return nil, newError("invalid AEAD nonce size: ", len(iv))
} }
additionalData := v.AdditionalDataGenerator.Next() var additionalData []byte
if v.AdditionalDataGenerator != nil {
additionalData = v.AdditionalDataGenerator.Next()
}
return v.AEAD.Seal(dst, iv, plainText, additionalData), nil return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
} }
@ -93,7 +119,12 @@ func (r *AuthenticationReader) readSize() error {
sizeBytes := r.sizeParser.SizeBytes() sizeBytes := r.sizeParser.SizeBytes()
if r.buffer.Len() < sizeBytes { if r.buffer.Len() < sizeBytes {
r.buffer.Reset(buf.ReadFrom(r.buffer)) if r.buffer.IsEmpty() {
r.buffer.Clear()
} else {
common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
}
delta := sizeBytes - r.buffer.Len() delta := sizeBytes - r.buffer.Len()
if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil { if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
return err return err
@ -146,18 +177,18 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
return b, nil return b, nil
} }
func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) { func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b, err := r.readChunk(true) b, err := r.readChunk(true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
mb := buf.NewMultiBuffer() var mb buf.MultiBuffer
if r.transferType == protocol.TransferTypeStream { if r.transferType == protocol.TransferTypeStream {
mb.Write(b) mb.Write(b)
} else { } else {
var bb *buf.Buffer var bb *buf.Buffer
if len(b) < buf.Size { if len(b) <= buf.Size {
bb = buf.New() bb = buf.New()
} else { } else {
bb = buf.NewLocal(len(b)) bb = buf.NewLocal(len(b))
@ -175,7 +206,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
mb.Write(b) mb.Write(b)
} else { } else {
var bb *buf.Buffer var bb *buf.Buffer
if len(b) < buf.Size { if len(b) <= buf.Size {
bb = buf.New() bb = buf.New()
} else { } else {
bb = buf.NewLocal(len(b)) bb = buf.NewLocal(len(b))
@ -190,79 +221,92 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
type AuthenticationWriter struct { type AuthenticationWriter struct {
auth Authenticator auth Authenticator
buffer []byte writer buf.Writer
payload []byte
writer *buf.BufferedWriter
sizeParser ChunkSizeEncoder sizeParser ChunkSizeEncoder
transferType protocol.TransferType transferType protocol.TransferType
} }
func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter { func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
const payloadSize = 1024
return &AuthenticationWriter{ return &AuthenticationWriter{
auth: auth, auth: auth,
buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()), writer: buf.NewWriter(writer),
payload: make([]byte, payloadSize),
writer: buf.NewBufferedWriterSize(writer, readerBufferSize),
sizeParser: sizeParser, sizeParser: sizeParser,
transferType: transferType, transferType: transferType,
} }
} }
func (w *AuthenticationWriter) append(b []byte) error { func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
encryptedSize := len(b) + w.auth.Overhead() encryptedSize := b.Len() + w.auth.Overhead()
buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0])
buffer, err := w.auth.Seal(buffer, b) eb := buf.New()
if err != nil { common.Must(eb.Reset(func(bb []byte) (int, error) {
return err w.sizeParser.Encode(uint16(encryptedSize), bb[:0])
return w.sizeParser.SizeBytes(), nil
}))
if err := eb.AppendSupplier(func(bb []byte) (int, error) {
_, err := w.auth.Seal(bb[:0], b.Bytes())
return encryptedSize, err
}); err != nil {
eb.Release()
return nil, err
} }
if _, err := w.writer.Write(buffer); err != nil { return eb, nil
return err
}
return nil
} }
func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error { func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
payloadSize := buf.Size - w.auth.Overhead() - w.sizeParser.SizeBytes()
mb2Write := buf.NewMultiBufferCap(len(mb) + 10)
for { for {
n, _ := mb.Read(w.payload) b := buf.New()
if err := w.append(w.payload[:n]); err != nil { common.Must(b.Reset(func(bb []byte) (int, error) {
return mb.Read(bb[:payloadSize])
}))
eb, err := w.seal(b)
b.Release()
if err != nil {
mb2Write.Release()
return err return err
} }
mb2Write.Append(eb)
if mb.IsEmpty() { if mb.IsEmpty() {
break break
} }
} }
return w.writer.Flush() return w.writer.WriteMultiBuffer(mb2Write)
} }
func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error { func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
defer mb.Release() defer mb.Release()
mb2Write := buf.NewMultiBufferCap(len(mb) * 2)
for { for {
b := mb.SplitFirst() b := mb.SplitFirst()
if b == nil { if b == nil {
b = buf.New() b = buf.New()
} }
if err := w.append(b.Bytes()); err != nil { eb, err := w.seal(b)
b.Release() b.Release()
if err != nil {
mb2Write.Release()
return err return err
} }
b.Release() mb2Write.Append(eb)
if mb.IsEmpty() { if mb.IsEmpty() {
break break
} }
} }
return w.writer.Flush() return w.writer.WriteMultiBuffer(mb2Write)
} }
func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error { func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.transferType == protocol.TransferTypeStream { if w.transferType == protocol.TransferTypeStream {
return w.writeStream(mb) return w.writeStream(mb)
} }

View File

@ -10,19 +10,19 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
. "v2ray.com/core/common/crypto" . "v2ray.com/core/common/crypto"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestAuthenticationReaderWriter(t *testing.T) { func TestAuthenticationReaderWriter(t *testing.T) {
assert := assert.On(t) assert := With(t)
key := make([]byte, 16) key := make([]byte, 16)
rand.Read(key) rand.Read(key)
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
assert.Error(err).IsNil() assert(err, IsNil)
aead, err := cipher.NewGCM(block) aead, err := cipher.NewGCM(block)
assert.Error(err).IsNil() assert(err, IsNil)
rawPayload := make([]byte, 8192*10) rawPayload := make([]byte, 8192*10)
rand.Read(rawPayload) rand.Read(rawPayload)
@ -42,10 +42,10 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{}, AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream) }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil() assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
assert.Int(cache.Len()).Equals(83360) assert(cache.Len(), Equals, 82658)
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil() assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert.Error(err).IsNil() assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{ reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead, AEAD: aead,
@ -55,33 +55,33 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{}, AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream) }, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
mb := buf.NewMultiBuffer() var mb buf.MultiBuffer
for mb.Len() < len(rawPayload) { for mb.Len() < len(rawPayload) {
mb2, err := reader.Read() mb2, err := reader.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
mb.AppendMulti(mb2) mb.AppendMulti(mb2)
} }
mbContent := make([]byte, 8192*10) mbContent := make([]byte, 8192*10)
mb.Read(mbContent) mb.Read(mbContent)
assert.Bytes(mbContent).Equals(rawPayload) assert(mbContent, Equals, rawPayload)
_, err = reader.Read() _, err = reader.ReadMultiBuffer()
assert.Error(err).Equals(io.EOF) assert(err, Equals, io.EOF)
} }
func TestAuthenticationReaderWriterPacket(t *testing.T) { func TestAuthenticationReaderWriterPacket(t *testing.T) {
assert := assert.On(t) assert := With(t)
key := make([]byte, 16) key := make([]byte, 16)
rand.Read(key) rand.Read(key)
block, err := aes.NewCipher(key) block, err := aes.NewCipher(key)
assert.Error(err).IsNil() assert(err, IsNil)
aead, err := cipher.NewGCM(block) aead, err := cipher.NewGCM(block)
assert.Error(err).IsNil() assert(err, IsNil)
cache := buf.NewLocal(1024) cache := buf.NewLocal(1024)
iv := make([]byte, 12) iv := make([]byte, 12)
@ -95,7 +95,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{}, AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket) }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
payload := buf.NewMultiBuffer() var payload buf.MultiBuffer
pb1 := buf.New() pb1 := buf.New()
pb1.Append([]byte("abcd")) pb1.Append([]byte("abcd"))
payload.Append(pb1) payload.Append(pb1)
@ -104,10 +104,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
pb2.Append([]byte("efgh")) pb2.Append([]byte("efgh"))
payload.Append(pb2) payload.Append(pb2)
assert.Error(writer.Write(payload)).IsNil() assert(writer.WriteMultiBuffer(payload), IsNil)
assert.Int(cache.Len()).GreaterThan(0) assert(cache.Len(), GreaterThan, 0)
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil() assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert.Error(err).IsNil() assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{ reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead, AEAD: aead,
@ -117,15 +117,15 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{}, AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket) }, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
mb, err := reader.Read() mb, err := reader.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
b1 := mb.SplitFirst() b1 := mb.SplitFirst()
assert.String(b1.String()).Equals("abcd") assert(b1.String(), Equals, "abcd")
b2 := mb.SplitFirst() b2 := mb.SplitFirst()
assert.String(b2.String()).Equals("efgh") assert(b2.String(), Equals, "efgh")
assert.Bool(mb.IsEmpty()).IsTrue() assert(mb.IsEmpty(), IsTrue)
_, err = reader.Read() _, err = reader.ReadMultiBuffer()
assert.Error(err).Equals(io.EOF) assert(err, Equals, io.EOF)
} }

View File

@ -7,7 +7,7 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
. "v2ray.com/core/common/crypto" . "v2ray.com/core/common/crypto"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func mustDecodeHex(s string) []byte { func mustDecodeHex(s string) []byte {
@ -17,7 +17,7 @@ func mustDecodeHex(s string) []byte {
} }
func TestChaCha20Stream(t *testing.T) { func TestChaCha20Stream(t *testing.T) {
assert := assert.On(t) assert := With(t)
var cases = []struct { var cases = []struct {
key []byte key []byte
@ -51,12 +51,12 @@ func TestChaCha20Stream(t *testing.T) {
input := make([]byte, len(c.output)) input := make([]byte, len(c.output))
actualOutout := make([]byte, len(c.output)) actualOutout := make([]byte, len(c.output))
s.XORKeyStream(actualOutout, input) s.XORKeyStream(actualOutout, input)
assert.Bytes(c.output).Equals(actualOutout) assert(c.output, Equals, actualOutout)
} }
} }
func TestChaCha20Decoding(t *testing.T) { func TestChaCha20Decoding(t *testing.T) {
assert := assert.On(t) assert := With(t)
key := make([]byte, 32) key := make([]byte, 32)
rand.Read(key) rand.Read(key)
@ -72,5 +72,5 @@ func TestChaCha20Decoding(t *testing.T) {
stream2 := NewChaCha20Stream(key, iv) stream2 := NewChaCha20Stream(key, iv)
stream2.XORKeyStream(x, x) stream2.XORKeyStream(x, x)
assert.Bytes(x).Equals(payload) assert(x, Equals, payload)
} }

View File

@ -3,15 +3,18 @@ package crypto
import ( import (
"io" "io"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/serial" "v2ray.com/core/common/serial"
) )
// ChunkSizeDecoder is an utility class to decode size value from bytes.
type ChunkSizeDecoder interface { type ChunkSizeDecoder interface {
SizeBytes() int SizeBytes() int
Decode([]byte) (uint16, error) Decode([]byte) (uint16, error)
} }
// ChunkSizeEncoder is an utility class to encode size value into bytes.
type ChunkSizeEncoder interface { type ChunkSizeEncoder interface {
SizeBytes() int SizeBytes() int
Encode(uint16, []byte) []byte Encode(uint16, []byte) []byte
@ -31,50 +34,53 @@ func (PlainChunkSizeParser) Decode(b []byte) (uint16, error) {
return serial.BytesToUint16(b), nil return serial.BytesToUint16(b), nil
} }
type AEADChunkSizeParser struct {
Auth *AEADAuthenticator
}
func (p *AEADChunkSizeParser) SizeBytes() int {
return 2 + p.Auth.Overhead()
}
func (p *AEADChunkSizeParser) Encode(size uint16, b []byte) []byte {
b = serial.Uint16ToBytes(size-uint16(p.Auth.Overhead()), b)
b, err := p.Auth.Seal(b[:0], b)
common.Must(err)
return b
}
func (p *AEADChunkSizeParser) Decode(b []byte) (uint16, error) {
b, err := p.Auth.Open(b[:0], b)
if err != nil {
return 0, err
}
return serial.BytesToUint16(b) + uint16(p.Auth.Overhead()), nil
}
type ChunkStreamReader struct { type ChunkStreamReader struct {
sizeDecoder ChunkSizeDecoder sizeDecoder ChunkSizeDecoder
reader buf.Reader reader *buf.BufferedReader
buffer []byte buffer []byte
leftOver buf.MultiBuffer
leftOverSize int leftOverSize int
} }
func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader { func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader {
return &ChunkStreamReader{ return &ChunkStreamReader{
sizeDecoder: sizeDecoder, sizeDecoder: sizeDecoder,
reader: buf.NewReader(reader), reader: buf.NewBufferedReader(buf.NewReader(reader)),
buffer: make([]byte, sizeDecoder.SizeBytes()), buffer: make([]byte, sizeDecoder.SizeBytes()),
} }
} }
func (r *ChunkStreamReader) readAtLeast(size int) error {
mb := r.leftOver
r.leftOver = nil
for mb.Len() < size {
extra, err := r.reader.Read()
if err != nil {
mb.Release()
return err
}
mb.AppendMulti(extra)
}
r.leftOver = mb
return nil
}
func (r *ChunkStreamReader) readSize() (uint16, error) { func (r *ChunkStreamReader) readSize() (uint16, error) {
if r.sizeDecoder.SizeBytes() > r.leftOver.Len() { if _, err := io.ReadFull(r.reader, r.buffer); err != nil {
if err := r.readAtLeast(r.sizeDecoder.SizeBytes() - r.leftOver.Len()); err != nil {
return 0, err return 0, err
} }
}
r.leftOver.Read(r.buffer)
return r.sizeDecoder.Decode(r.buffer) return r.sizeDecoder.Decode(r.buffer)
} }
func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) { func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
size := r.leftOverSize size := r.leftOverSize
if size == 0 { if size == 0 {
nextSize, err := r.readSize() nextSize, err := r.readSize()
@ -86,30 +92,15 @@ func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) {
} }
size = int(nextSize) size = int(nextSize)
} }
r.leftOverSize = size
if r.leftOver.IsEmpty() { mb, err := r.reader.ReadAtMost(size)
if err := r.readAtLeast(1); err != nil { if !mb.IsEmpty() {
r.leftOverSize -= mb.Len()
return mb, nil
}
return nil, err return nil, err
} }
}
if size >= r.leftOver.Len() {
mb := r.leftOver
r.leftOverSize = size - r.leftOver.Len()
r.leftOver = nil
return mb, nil
}
mb := r.leftOver.SliceBySize(size)
if mb.Len() != size {
b := buf.New()
b.AppendSupplier(buf.ReadFullFrom(&r.leftOver, size-mb.Len()))
mb.Append(b)
}
r.leftOverSize = 0
return mb, nil
}
type ChunkStreamWriter struct { type ChunkStreamWriter struct {
sizeEncoder ChunkSizeEncoder sizeEncoder ChunkSizeEncoder
@ -123,18 +114,19 @@ func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *Chunk
} }
} }
func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error { func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
mb2Write := buf.NewMultiBuffer()
const sliceSize = 8192 const sliceSize = 8192
mbLen := mb.Len()
mb2Write := buf.NewMultiBufferCap(mbLen/buf.Size + mbLen/sliceSize + 2)
for { for {
slice := mb.SliceBySize(sliceSize) slice := mb.SliceBySize(sliceSize)
b := buf.New() b := buf.New()
b.AppendSupplier(func(buffer []byte) (int, error) { common.Must(b.Reset(func(buffer []byte) (int, error) {
w.sizeEncoder.Encode(uint16(slice.Len()), buffer[:0]) w.sizeEncoder.Encode(uint16(slice.Len()), buffer[:0])
return w.sizeEncoder.SizeBytes(), nil return w.sizeEncoder.SizeBytes(), nil
}) }))
mb2Write.Append(b) mb2Write.Append(b)
mb2Write.AppendMulti(slice) mb2Write.AppendMulti(slice)
@ -143,5 +135,5 @@ func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
} }
} }
return w.writer.Write(mb2Write) return w.writer.WriteMultiBuffer(mb2Write)
} }

View File

@ -6,11 +6,11 @@ import (
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
. "v2ray.com/core/common/crypto" . "v2ray.com/core/common/crypto"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestChunkStreamIO(t *testing.T) { func TestChunkStreamIO(t *testing.T) {
assert := assert.On(t) assert := With(t)
cache := buf.NewLocal(8192) cache := buf.NewLocal(8192)
@ -19,26 +19,26 @@ func TestChunkStreamIO(t *testing.T) {
b := buf.New() b := buf.New()
b.AppendBytes('a', 'b', 'c', 'd') b.AppendBytes('a', 'b', 'c', 'd')
assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil() assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
b = buf.New() b = buf.New()
b.AppendBytes('e', 'f', 'g') b.AppendBytes('e', 'f', 'g')
assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil() assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil() assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert.Int(cache.Len()).Equals(13) assert(cache.Len(), Equals, 13)
mb, err := reader.Read() mb, err := reader.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(mb.Len()).Equals(4) assert(mb.Len(), Equals, 4)
assert.Bytes(mb[0].Bytes()).Equals([]byte("abcd")) assert(mb[0].Bytes(), Equals, []byte("abcd"))
mb, err = reader.Read() mb, err = reader.ReadMultiBuffer()
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int(mb.Len()).Equals(3) assert(mb.Len(), Equals, 3)
assert.Bytes(mb[0].Bytes()).Equals([]byte("efg")) assert(mb[0].Bytes(), Equals, []byte("efg"))
_, err = reader.Read() _, err = reader.ReadMultiBuffer()
assert.Error(err).Equals(io.EOF) assert(err, Equals, io.EOF)
} }

View File

@ -28,7 +28,7 @@ func (r *CryptionReader) Read(data []byte) (int, error) {
} }
var ( var (
_ buf.MultiBufferWriter = (*CryptionWriter)(nil) _ buf.Writer = (*CryptionWriter)(nil)
) )
type CryptionWriter struct { type CryptionWriter struct {
@ -51,6 +51,8 @@ func (w *CryptionWriter) Write(data []byte) (int, error) {
} }
func (w *CryptionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { func (w *CryptionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
bs := mb.ToNetBuffers() bs := mb.ToNetBuffers()
for _, b := range bs { for _, b := range bs {
w.stream.XORKeyStream(b, b) w.stream.XORKeyStream(b, b)

View File

@ -5,29 +5,29 @@ import (
"testing" "testing"
. "v2ray.com/core/common/errors" . "v2ray.com/core/common/errors"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestError(t *testing.T) { func TestError(t *testing.T) {
assert := assert.On(t) assert := With(t)
err := New("TestError") err := New("TestError")
assert.Bool(GetSeverity(err) == SeverityInfo).IsTrue() assert(GetSeverity(err), Equals, SeverityInfo)
err = New("TestError2").Base(io.EOF) err = New("TestError2").Base(io.EOF)
assert.Bool(GetSeverity(err) == SeverityInfo).IsTrue() assert(GetSeverity(err), Equals, SeverityInfo)
err = New("TestError3").Base(io.EOF).AtWarning() err = New("TestError3").Base(io.EOF).AtWarning()
assert.Bool(GetSeverity(err) == SeverityWarning).IsTrue() assert(GetSeverity(err), Equals, SeverityWarning)
err = New("TestError4").Base(io.EOF).AtWarning() err = New("TestError4").Base(io.EOF).AtWarning()
err = New("TestError5").Base(err) err = New("TestError5").Base(err)
assert.Bool(GetSeverity(err) == SeverityWarning).IsTrue() assert(GetSeverity(err), Equals, SeverityWarning)
assert.String(err.Error()).Contains("EOF") assert(err.Error(), HasSubstring, "EOF")
} }
func TestErrorMessage(t *testing.T) { func TestErrorMessage(t *testing.T) {
assert := assert.On(t) assert := With(t)
data := []struct { data := []struct {
err error err error
@ -44,6 +44,6 @@ func TestErrorMessage(t *testing.T) {
} }
for _, d := range data { for _, d := range data {
assert.String(d.err.Error()).Equals(d.msg) assert(d.err.Error(), Equals, d.msg)
} }
} }

46
common/event/event.go Normal file
View File

@ -0,0 +1,46 @@
package event
import "sync"
type Event uint16
type Handler func(data interface{}) error
type Registry interface {
On(Event, Handler)
}
type Listener struct {
sync.RWMutex
events map[Event][]Handler
}
func (l *Listener) On(e Event, h Handler) {
l.Lock()
defer l.Unlock()
if l.events == nil {
l.events = make(map[Event][]Handler)
}
handlers := l.events[e]
handlers = append(handlers, h)
l.events[e] = handlers
}
func (l *Listener) Fire(e Event, data interface{}) error {
l.RLock()
defer l.RUnlock()
if l.events == nil {
return nil
}
for _, h := range l.events[e] {
if err := h(data); err != nil {
return err
}
}
return nil
}

View File

@ -73,6 +73,12 @@ type Address interface {
// ParseAddress parses a string into an Address. The return value will be an IPAddress when // ParseAddress parses a string into an Address. The return value will be an IPAddress when
// the string is in the form of IPv4 or IPv6 address, or a DomainAddress otherwise. // the string is in the form of IPv4 or IPv6 address, or a DomainAddress otherwise.
func ParseAddress(addr string) Address { func ParseAddress(addr string) Address {
// Handle IPv6 address in form as "[2001:4860:0:2001::68]"
lenAddr := len(addr)
if lenAddr > 0 && addr[0] == '[' && addr[lenAddr-1] == ']' {
addr = addr[1 : lenAddr-1]
}
ip := net.ParseIP(addr) ip := net.ParseIP(addr)
if ip != nil { if ip != nil {
return IPAddress(ip) return IPAddress(ip)

View File

@ -5,24 +5,25 @@ import (
"testing" "testing"
. "v2ray.com/core/common/net" . "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert" . "v2ray.com/core/common/net/testing"
. "v2ray.com/ext/assert"
) )
func TestIPv4Address(t *testing.T) { func TestIPv4Address(t *testing.T) {
assert := assert.On(t) assert := With(t)
ip := []byte{byte(1), byte(2), byte(3), byte(4)} ip := []byte{byte(1), byte(2), byte(3), byte(4)}
addr := IPAddress(ip) addr := IPAddress(ip)
assert.Address(addr).IsIPv4() assert(addr, IsIPv4)
assert.Address(addr).IsNotIPv6() assert(addr, Not(IsIPv6))
assert.Address(addr).IsNotDomain() assert(addr, Not(IsDomain))
assert.Bytes(addr.IP()).Equals(ip) assert([]byte(addr.IP()), Equals, ip)
assert.Address(addr).EqualsString("1.2.3.4") assert(addr.String(), Equals, "1.2.3.4")
} }
func TestIPv6Address(t *testing.T) { func TestIPv6Address(t *testing.T) {
assert := assert.On(t) assert := With(t)
ip := []byte{ ip := []byte{
byte(1), byte(2), byte(3), byte(4), byte(1), byte(2), byte(3), byte(4),
@ -32,15 +33,15 @@ func TestIPv6Address(t *testing.T) {
} }
addr := IPAddress(ip) addr := IPAddress(ip)
assert.Address(addr).IsIPv6() assert(addr, IsIPv6)
assert.Address(addr).IsNotIPv4() assert(addr, Not(IsIPv4))
assert.Address(addr).IsNotDomain() assert(addr, Not(IsDomain))
assert.IP(addr.IP()).Equals(net.IP(ip)) assert(addr.IP(), Equals, net.IP(ip))
assert.Address(addr).EqualsString("[102:304:102:304:102:304:102:304]") assert(addr.String(), Equals, "[102:304:102:304:102:304:102:304]")
} }
func TestIPv4Asv6(t *testing.T) { func TestIPv4Asv6(t *testing.T) {
assert := assert.On(t) assert := With(t)
ip := []byte{ ip := []byte{
byte(0), byte(0), byte(0), byte(0), byte(0), byte(0), byte(0), byte(0),
byte(0), byte(0), byte(0), byte(0), byte(0), byte(0), byte(0), byte(0),
@ -48,27 +49,55 @@ func TestIPv4Asv6(t *testing.T) {
byte(1), byte(2), byte(3), byte(4), byte(1), byte(2), byte(3), byte(4),
} }
addr := IPAddress(ip) addr := IPAddress(ip)
assert.Address(addr).EqualsString("1.2.3.4") assert(addr.String(), Equals, "1.2.3.4")
} }
func TestDomainAddress(t *testing.T) { func TestDomainAddress(t *testing.T) {
assert := assert.On(t) assert := With(t)
domain := "v2ray.com" domain := "v2ray.com"
addr := DomainAddress(domain) addr := DomainAddress(domain)
assert.Address(addr).IsDomain() assert(addr, IsDomain)
assert.Address(addr).IsNotIPv6() assert(addr, Not(IsIPv6))
assert.Address(addr).IsNotIPv4() assert(addr, Not(IsIPv4))
assert.String(addr.Domain()).Equals(domain) assert(addr.Domain(), Equals, domain)
assert.Address(addr).EqualsString("v2ray.com") assert(addr.String(), Equals, "v2ray.com")
} }
func TestNetIPv4Address(t *testing.T) { func TestNetIPv4Address(t *testing.T) {
assert := assert.On(t) assert := With(t)
ip := net.IPv4(1, 2, 3, 4) ip := net.IPv4(1, 2, 3, 4)
addr := IPAddress(ip) addr := IPAddress(ip)
assert.Address(addr).IsIPv4() assert(addr, IsIPv4)
assert.Address(addr).EqualsString("1.2.3.4") assert(addr.String(), Equals, "1.2.3.4")
}
func TestParseIPv6Address(t *testing.T) {
assert := With(t)
ip := ParseAddress("[2001:4860:0:2001::68]")
assert(ip, IsIPv6)
assert(ip.String(), Equals, "[2001:4860:0:2001::68]")
ip = ParseAddress("[::ffff:123.151.71.143]")
assert(ip, IsIPv4)
assert(ip.String(), Equals, "123.151.71.143")
}
func TestInvalidAddressConvertion(t *testing.T) {
assert := With(t)
assert(func() { ParseAddress("8.8.8.8").Domain() }, Panics)
assert(func() { ParseAddress("2001:4860:0:2001::68").Domain() }, Panics)
assert(func() { ParseAddress("v2ray.com").IP() }, Panics)
}
func TestIPOrDomain(t *testing.T) {
assert := With(t)
assert(NewIPOrDomain(ParseAddress("v2ray.com")).AsAddress(), Equals, ParseAddress("v2ray.com"))
assert(NewIPOrDomain(ParseAddress("8.8.8.8")).AsAddress(), Equals, ParseAddress("8.8.8.8"))
assert(NewIPOrDomain(ParseAddress("2001:4860:0:2001::68")).AsAddress(), Equals, ParseAddress("2001:4860:0:2001::68"))
} }

View File

@ -41,14 +41,17 @@ func UDPDestination(address Address, port Port) Destination {
} }
} }
// NetAddr returns the network address in this Destination in string form.
func (d Destination) NetAddr() string { func (d Destination) NetAddr() string {
return d.Address.String() + ":" + d.Port.String() return d.Address.String() + ":" + d.Port.String()
} }
// String returns the strings form of this Destination.
func (d Destination) String() string { func (d Destination) String() string {
return d.Network.URLPrefix() + ":" + d.NetAddr() return d.Network.URLPrefix() + ":" + d.NetAddr()
} }
// IsValid returns true if this Destination is valid.
func (d Destination) IsValid() bool { func (d Destination) IsValid() bool {
return d.Network != Network_Unknown return d.Network != Network_Unknown
} }

View File

@ -4,23 +4,24 @@ import (
"testing" "testing"
. "v2ray.com/core/common/net" . "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert" . "v2ray.com/core/common/net/testing"
. "v2ray.com/ext/assert"
) )
func TestTCPDestination(t *testing.T) { func TestTCPDestination(t *testing.T) {
assert := assert.On(t) assert := With(t)
dest := TCPDestination(IPAddress([]byte{1, 2, 3, 4}), 80) dest := TCPDestination(IPAddress([]byte{1, 2, 3, 4}), 80)
assert.Destination(dest).IsTCP() assert(dest, IsTCP)
assert.Destination(dest).IsNotUDP() assert(dest, Not(IsUDP))
assert.Destination(dest).EqualsString("tcp:1.2.3.4:80") assert(dest.String(), Equals, "tcp:1.2.3.4:80")
} }
func TestUDPDestination(t *testing.T) { func TestUDPDestination(t *testing.T) {
assert := assert.On(t) assert := With(t)
dest := UDPDestination(IPAddress([]byte{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88}), 53) dest := UDPDestination(IPAddress([]byte{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88}), 53)
assert.Destination(dest).IsNotTCP() assert(dest, Not(IsTCP))
assert.Destination(dest).IsUDP() assert(dest, IsUDP)
assert.Destination(dest).EqualsString("udp:[2001:4860:4860::8888]:53") assert(dest.String(), Equals, "udp:[2001:4860:4860::8888]:53")
} }

View File

@ -44,6 +44,7 @@ func (n *IPNetTable) Add(ipNet *net.IPNet) {
func (n *IPNetTable) AddIP(ip []byte, mask byte) { func (n *IPNetTable) AddIP(ip []byte, mask byte) {
k := ipToUint32(ip) k := ipToUint32(ip)
k = (k >> (32 - mask)) << (32 - mask) // normalize ip
existing, found := n.cache[k] existing, found := n.cache[k]
if !found || existing > mask { if !found || existing > mask {
n.cache[k] = mask n.cache[k] = mask

View File

@ -2,11 +2,19 @@ package net_test
import ( import (
"net" "net"
"os"
"path/filepath"
"testing" "testing"
proto "github.com/golang/protobuf/proto"
"v2ray.com/core/app/router"
"v2ray.com/core/common/platform"
"v2ray.com/ext/sysio"
"v2ray.com/core/common" "v2ray.com/core/common"
. "v2ray.com/core/common/net" . "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func parseCIDR(str string) *net.IPNet { func parseCIDR(str string) *net.IPNet {
@ -16,7 +24,7 @@ func parseCIDR(str string) *net.IPNet {
} }
func TestIPNet(t *testing.T) { func TestIPNet(t *testing.T) {
assert := assert.On(t) assert := With(t)
ipNet := NewIPNetTable() ipNet := NewIPNetTable()
ipNet.Add(parseCIDR(("0.0.0.0/8"))) ipNet.Add(parseCIDR(("0.0.0.0/8")))
@ -32,12 +40,94 @@ func TestIPNet(t *testing.T) {
ipNet.Add(parseCIDR(("198.51.100.0/24"))) ipNet.Add(parseCIDR(("198.51.100.0/24")))
ipNet.Add(parseCIDR(("203.0.113.0/24"))) ipNet.Add(parseCIDR(("203.0.113.0/24")))
ipNet.Add(parseCIDR(("8.8.8.8/32"))) ipNet.Add(parseCIDR(("8.8.8.8/32")))
assert.Bool(ipNet.Contains(ParseIP("192.168.1.1"))).IsTrue() ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
assert.Bool(ipNet.Contains(ParseIP("192.0.0.0"))).IsTrue() assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
assert.Bool(ipNet.Contains(ParseIP("192.0.1.0"))).IsFalse() assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
assert.Bool(ipNet.Contains(ParseIP("0.1.0.0"))).IsTrue() assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
assert.Bool(ipNet.Contains(ParseIP("1.0.0.1"))).IsFalse() assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
assert.Bool(ipNet.Contains(ParseIP("8.8.8.7"))).IsFalse() assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
assert.Bool(ipNet.Contains(ParseIP("8.8.8.8"))).IsTrue() assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
assert.Bool(ipNet.Contains(ParseIP("2001:cdba::3257:9652"))).IsFalse() assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
}
func TestGeoIPCN(t *testing.T) {
assert := With(t)
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := NewIPNetTable()
for _, ip := range ips {
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
}
assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
}
func loadGeoIP(country string) ([]*router.CIDR, error) {
geoipBytes, err := sysio.ReadAsset("geoip.dat")
if err != nil {
return nil, err
}
var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
}
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
}
}
panic("country not found: " + country)
}
func BenchmarkIPNetQuery(b *testing.B) {
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := NewIPNetTable()
for _, ip := range ips {
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ipNet.Contains([]byte{8, 8, 8, 8})
}
}
func BenchmarkCIDRQuery(b *testing.B) {
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := make([]*net.IPNet, 0, 1024)
for _, ip := range ips {
if len(ip.Ip) != 4 {
continue
}
ipNet = append(ipNet, &net.IPNet{
IP: net.IP(ip.Ip),
Mask: net.CIDRMask(int(ip.Prefix), 32),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, n := range ipNet {
if n.Contains([]byte{8, 8, 8, 8}) {
break
}
}
}
} }

View File

@ -1,4 +1,4 @@
// Package net contains common network utilities. // Package net is a drop-in replacement to Golang's net package, with some more functionalities.
package net package net
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg net -path Net //go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg net -path Net

View File

@ -4,15 +4,15 @@ import (
"testing" "testing"
. "v2ray.com/core/common/net" . "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestPortRangeContains(t *testing.T) { func TestPortRangeContains(t *testing.T) {
assert := assert.On(t) assert := With(t)
portRange := &PortRange{ portRange := &PortRange{
From: 53, From: 53,
To: 53, To: 53,
} }
assert.Bool(portRange.Contains(Port(53))).IsTrue() assert(portRange.Contains(Port(53)), IsTrue)
} }

View File

@ -2,6 +2,7 @@ package net
import "net" import "net"
// DialTCP is an injectable function. Default to net.DialTCP
var DialTCP = net.DialTCP var DialTCP = net.DialTCP
var DialUDP = net.DialUDP var DialUDP = net.DialUDP
var DialUnix = net.DialUnix var DialUnix = net.DialUnix
@ -31,6 +32,7 @@ type UDPConn = net.UDPConn
type UnixAddr = net.UnixAddr type UnixAddr = net.UnixAddr
type UnixConn = net.UnixConn type UnixConn = net.UnixConn
// IP is an alias for net.IP.
type IP = net.IP type IP = net.IP
type IPMask = net.IPMask type IPMask = net.IPMask
type IPNet = net.IPNet type IPNet = net.IPNet

View File

@ -0,0 +1,48 @@
package testing
import (
"v2ray.com/core/common/net"
"v2ray.com/ext/assert"
)
var IsIPv4 = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsIPv4()
}, "is IPv4")
var IsIPv6 = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsIPv6()
}, "is IPv6")
var IsIP = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsIPv4() || a.Family().IsIPv6()
}, "is IP")
var IsTCP = assert.CreateMatcher(func(a net.Destination) bool {
return a.Network == net.Network_TCP
}, "is TCP")
var IsUDP = assert.CreateMatcher(func(a net.Destination) bool {
return a.Network == net.Network_UDP
}, "is UDP")
var IsDomain = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsDomain()
}, "is Domain")
func init() {
assert.RegisterEqualsMatcher(func(a, b net.Address) bool {
return a == b
})
assert.RegisterEqualsMatcher(func(a, b net.Destination) bool {
return a == b
})
assert.RegisterEqualsMatcher(func(a, b net.Port) bool {
return a == b
})
assert.RegisterEqualsMatcher(func(a, b net.IP) bool {
return a.Equal(b)
})
}

View File

@ -4,6 +4,7 @@ package platform
import ( import (
"os" "os"
"path/filepath"
) )
func ExpandEnv(s string) string { func ExpandEnv(s string) string {
@ -13,3 +14,9 @@ func ExpandEnv(s string) string {
func LineSeparator() string { func LineSeparator() string {
return "\n" return "\n"
} }
func GetToolLocation(file string) string {
const name = "v2ray.location.tool"
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(toolPath, file)
}

View File

@ -2,6 +2,7 @@ package platform
import ( import (
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
) )
@ -11,7 +12,7 @@ type EnvFlag struct {
AltName string AltName string
} }
func (f EnvFlag) GetValue(defaultValue string) string { func (f EnvFlag) GetValue(defaultValue func() string) string {
if v, found := os.LookupEnv(f.Name); found { if v, found := os.LookupEnv(f.Name); found {
return v return v
} }
@ -21,13 +22,16 @@ func (f EnvFlag) GetValue(defaultValue string) string {
} }
} }
return defaultValue return defaultValue()
} }
func (f EnvFlag) GetValueAsInt(defaultValue int) int { func (f EnvFlag) GetValueAsInt(defaultValue int) int {
const PlaceHolder = "xxxxxx" useDefaultValue := false
s := f.GetValue(PlaceHolder) s := f.GetValue(func() string {
if s == PlaceHolder { useDefaultValue = true
return ""
})
if useDefaultValue {
return defaultValue return defaultValue
} }
v, err := strconv.ParseInt(s, 10, 32) v, err := strconv.ParseInt(s, 10, 32)
@ -40,3 +44,29 @@ func (f EnvFlag) GetValueAsInt(defaultValue int) int {
func NormalizeEnvName(name string) string { func NormalizeEnvName(name string) string {
return strings.Replace(strings.ToUpper(strings.TrimSpace(name)), ".", "_", -1) return strings.Replace(strings.ToUpper(strings.TrimSpace(name)), ".", "_", -1)
} }
func getExecutableDir() string {
exec, err := os.Executable()
if err != nil {
return ""
}
return filepath.Dir(exec)
}
func getExecuableSubDir(dir string) func() string {
return func() string {
return filepath.Join(getExecutableDir(), dir)
}
}
func GetAssetLocation(file string) string {
const name = "v2ray.location.asset"
assetPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(assetPath, file)
}
func GetPluginDirectory() string {
const name = "v2ray.location.plugin"
pluginDir := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecuableSubDir("plugins"))
return pluginDir
}

View File

@ -1,14 +1,16 @@
package platform_test package platform_test
import ( import (
"os"
"path/filepath"
"testing" "testing"
. "v2ray.com/core/common/platform" . "v2ray.com/core/common/platform"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestNormalizeEnvName(t *testing.T) { func TestNormalizeEnvName(t *testing.T) {
assert := assert.On(t) assert := With(t)
cases := []struct { cases := []struct {
input string input string
@ -28,14 +30,27 @@ func TestNormalizeEnvName(t *testing.T) {
}, },
} }
for _, test := range cases { for _, test := range cases {
assert.String(NormalizeEnvName(test.input)).Equals(test.output) assert(NormalizeEnvName(test.input), Equals, test.output)
} }
} }
func TestEnvFlag(t *testing.T) { func TestEnvFlag(t *testing.T) {
assert := assert.On(t) assert := With(t)
assert.Int(EnvFlag{ assert(EnvFlag{
Name: "xxxxx.y", Name: "xxxxx.y",
}.GetValueAsInt(10)).Equals(10) }.GetValueAsInt(10), Equals, 10)
}
func TestGetAssetLocation(t *testing.T) {
assert := With(t)
exec, err := os.Executable()
assert(err, IsNil)
loc := GetAssetLocation("t")
assert(filepath.Dir(loc), Equals, filepath.Dir(exec))
os.Setenv("v2ray.location.asset", "/v2ray")
assert(GetAssetLocation("t"), Equals, "/v2ray/t")
} }

View File

@ -2,6 +2,8 @@
package platform package platform
import "path/filepath"
func ExpandEnv(s string) string { func ExpandEnv(s string) string {
// TODO // TODO
return s return s
@ -10,3 +12,9 @@ func ExpandEnv(s string) string {
func LineSeparator() string { func LineSeparator() string {
return "\r\n" return "\r\n"
} }
func GetToolLocation(file string) string {
const name = "v2ray.location.tool"
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(toolPath, file+".exe")
}

View File

@ -3,6 +3,7 @@ package protocol
import ( import (
"runtime" "runtime"
"v2ray.com/core/common/bitmask"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/uuid" "v2ray.com/core/common/uuid"
) )
@ -24,35 +25,20 @@ func (c RequestCommand) TransferType() TransferType {
return TransferTypePacket return TransferTypePacket
} }
// RequestOption is the options of a request.
type RequestOption byte
const ( const (
// RequestOptionChunkStream indicates request payload is chunked. Each chunk consists of length, authentication and payload. // RequestOptionChunkStream indicates request payload is chunked. Each chunk consists of length, authentication and payload.
RequestOptionChunkStream = RequestOption(0x01) RequestOptionChunkStream bitmask.Byte = 0x01
// RequestOptionConnectionReuse indicates client side expects to reuse the connection. // RequestOptionConnectionReuse indicates client side expects to reuse the connection.
RequestOptionConnectionReuse = RequestOption(0x02) RequestOptionConnectionReuse bitmask.Byte = 0x02
RequestOptionChunkMasking = RequestOption(0x04) RequestOptionChunkMasking bitmask.Byte = 0x04
) )
func (v RequestOption) Has(option RequestOption) bool {
return (v & option) == option
}
func (v *RequestOption) Set(option RequestOption) {
*v = (*v | option)
}
func (v *RequestOption) Clear(option RequestOption) {
*v = (*v & (^option))
}
type Security byte type Security byte
func (v Security) Is(t SecurityType) bool { func (s Security) Is(t SecurityType) bool {
return v == Security(t) return s == Security(t)
} }
func NormSecurity(s Security) Security { func NormSecurity(s Security) Security {
@ -65,42 +51,28 @@ func NormSecurity(s Security) Security {
type RequestHeader struct { type RequestHeader struct {
Version byte Version byte
Command RequestCommand Command RequestCommand
Option RequestOption Option bitmask.Byte
Security Security Security Security
Port net.Port Port net.Port
Address net.Address Address net.Address
User *User User *User
} }
func (v *RequestHeader) Destination() net.Destination { func (h *RequestHeader) Destination() net.Destination {
if v.Command == RequestCommandUDP { if h.Command == RequestCommandUDP {
return net.UDPDestination(v.Address, v.Port) return net.UDPDestination(h.Address, h.Port)
} }
return net.TCPDestination(v.Address, v.Port) return net.TCPDestination(h.Address, h.Port)
} }
type ResponseOption byte
const ( const (
ResponseOptionConnectionReuse = ResponseOption(0x01) ResponseOptionConnectionReuse bitmask.Byte = 0x01
) )
func (v *ResponseOption) Set(option ResponseOption) {
*v = (*v | option)
}
func (v ResponseOption) Has(option ResponseOption) bool {
return (v & option) == option
}
func (v *ResponseOption) Clear(option ResponseOption) {
*v = (*v & (^option))
}
type ResponseCommand interface{} type ResponseCommand interface{}
type ResponseHeader struct { type ResponseHeader struct {
Option ResponseOption Option bitmask.Byte
Command ResponseCommand Command ResponseCommand
} }
@ -108,20 +80,21 @@ type CommandSwitchAccount struct {
Host net.Address Host net.Address
Port net.Port Port net.Port
ID *uuid.UUID ID *uuid.UUID
AlterIds uint16
Level uint32 Level uint32
AlterIds uint16
ValidMin byte ValidMin byte
} }
func (v *SecurityConfig) AsSecurity() Security { func (sc *SecurityConfig) AsSecurity() Security {
if v == nil { if sc == nil || sc.Type == SecurityType_AUTO {
return Security(SecurityType_LEGACY)
}
if v.Type == SecurityType_AUTO {
if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" { if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" {
return Security(SecurityType_AES128_GCM) return Security(SecurityType_AES128_GCM)
} }
return Security(SecurityType_CHACHA20_POLY1305) return Security(SecurityType_CHACHA20_POLY1305)
} }
return NormSecurity(Security(v.Type)) return NormSecurity(Security(sc.Type))
}
func IsDomainTooLong(domain string) bool {
return len(domain) > 256
} }

View File

@ -1,34 +0,0 @@
package protocol_test
import (
"testing"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
)
func TestRequestOptionSet(t *testing.T) {
assert := assert.On(t)
var option RequestOption
assert.Bool(option.Has(RequestOptionChunkStream)).IsFalse()
option.Set(RequestOptionChunkStream)
assert.Bool(option.Has(RequestOptionChunkStream)).IsTrue()
option.Set(RequestOptionChunkMasking)
assert.Bool(option.Has(RequestOptionChunkMasking)).IsTrue()
assert.Bool(option.Has(RequestOptionChunkStream)).IsTrue()
}
func TestRequestOptionClear(t *testing.T) {
assert := assert.On(t)
var option RequestOption
option.Set(RequestOptionChunkStream)
option.Set(RequestOptionChunkMasking)
option.Clear(RequestOptionChunkStream)
assert.Bool(option.Has(RequestOptionChunkStream)).IsFalse()
assert.Bool(option.Has(RequestOptionChunkMasking)).IsTrue()
}

View File

@ -6,12 +6,12 @@ import (
"v2ray.com/core/common/predicate" "v2ray.com/core/common/predicate"
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
"v2ray.com/core/common/uuid" "v2ray.com/core/common/uuid"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestCmdKey(t *testing.T) { func TestCmdKey(t *testing.T) {
assert := assert.On(t) assert := With(t)
id := NewID(uuid.New()) id := NewID(uuid.New())
assert.Bool(predicate.BytesAll(id.CmdKey(), 0)).IsFalse() assert(predicate.BytesAll(id.CmdKey(), 0), IsFalse)
} }

View File

@ -6,3 +6,11 @@ const (
TransferTypeStream TransferType = 0 TransferTypeStream TransferType = 0
TransferTypePacket TransferType = 1 TransferTypePacket TransferType = 1
) )
type AddressType byte
const (
AddressTypeIPv4 AddressType = 1
AddressTypeDomain AddressType = 2
AddressTypeIPv6 AddressType = 3
)

View File

@ -6,30 +6,30 @@ import (
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestServerList(t *testing.T) { func TestServerList(t *testing.T) {
assert := assert.On(t) assert := With(t)
list := NewServerList() list := NewServerList()
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid())) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
assert.Uint32(list.Size()).Equals(1) assert(list.Size(), Equals, uint32(1))
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second)))) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second))))
assert.Uint32(list.Size()).Equals(2) assert(list.Size(), Equals, uint32(2))
server := list.GetServer(1) server := list.GetServer(1)
assert.Port(server.Destination().Port).Equals(2) assert(server.Destination().Port, Equals, net.Port(2))
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
server = list.GetServer(1) server = list.GetServer(1)
assert.Pointer(server).IsNil() assert(server, IsNil)
server = list.GetServer(0) server = list.GetServer(0)
assert.Port(server.Destination().Port).Equals(1) assert(server.Destination().Port, Equals, net.Port(1))
} }
func TestServerPicker(t *testing.T) { func TestServerPicker(t *testing.T) {
assert := assert.On(t) assert := With(t)
list := NewServerList() list := NewServerList()
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid())) list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
@ -38,17 +38,17 @@ func TestServerPicker(t *testing.T) {
picker := NewRoundRobinServerPicker(list) picker := NewRoundRobinServerPicker(list)
server := picker.PickServer() server := picker.PickServer()
assert.Port(server.Destination().Port).Equals(1) assert(server.Destination().Port, Equals, net.Port(1))
server = picker.PickServer() server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(2) assert(server.Destination().Port, Equals, net.Port(2))
server = picker.PickServer() server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(3) assert(server.Destination().Port, Equals, net.Port(3))
server = picker.PickServer() server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(1) assert(server.Destination().Port, Equals, net.Port(1))
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
server = picker.PickServer() server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(1) assert(server.Destination().Port, Equals, net.Port(1))
server = picker.PickServer() server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(1) assert(server.Destination().Port, Equals, net.Port(1))
} }

View File

@ -5,27 +5,27 @@ import (
"time" "time"
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestAlwaysValidStrategy(t *testing.T) { func TestAlwaysValidStrategy(t *testing.T) {
assert := assert.On(t) assert := With(t)
strategy := AlwaysValid() strategy := AlwaysValid()
assert.Bool(strategy.IsValid()).IsTrue() assert(strategy.IsValid(), IsTrue)
strategy.Invalidate() strategy.Invalidate()
assert.Bool(strategy.IsValid()).IsTrue() assert(strategy.IsValid(), IsTrue)
} }
func TestTimeoutValidStrategy(t *testing.T) { func TestTimeoutValidStrategy(t *testing.T) {
assert := assert.On(t) assert := With(t)
strategy := BeforeTime(time.Now().Add(2 * time.Second)) strategy := BeforeTime(time.Now().Add(2 * time.Second))
assert.Bool(strategy.IsValid()).IsTrue() assert(strategy.IsValid(), IsTrue)
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
assert.Bool(strategy.IsValid()).IsFalse() assert(strategy.IsValid(), IsFalse)
strategy = BeforeTime(time.Now().Add(2 * time.Second)) strategy = BeforeTime(time.Now().Add(2 * time.Second))
strategy.Invalidate() strategy.Invalidate()
assert.Bool(strategy.IsValid()).IsFalse() assert(strategy.IsValid(), IsFalse)
} }

View File

@ -5,11 +5,11 @@ import (
"time" "time"
. "v2ray.com/core/common/protocol" . "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestGenerateRandomInt64InRange(t *testing.T) { func TestGenerateRandomInt64InRange(t *testing.T) {
assert := assert.On(t) assert := With(t)
base := time.Now().Unix() base := time.Now().Unix()
delta := 100 delta := 100
@ -17,7 +17,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
val := int64(generator()) val := int64(generator())
assert.Int64(val).AtMost(base + int64(delta)) assert(val, AtMost, base+int64(delta))
assert.Int64(val).AtLeast(base - int64(delta)) assert(val, AtLeast, base-int64(delta))
} }
} }

View File

@ -1,7 +1,5 @@
package protocol package protocol
import "time"
func (u *User) GetTypedAccount() (Account, error) { func (u *User) GetTypedAccount() (Account, error) {
if u.GetAccount() == nil { if u.GetAccount() == nil {
return nil, newError("Account missing").AtWarning() return nil, newError("Account missing").AtWarning()
@ -19,20 +17,3 @@ func (u *User) GetTypedAccount() (Account, error) {
} }
return nil, newError("Unknown account type: ", u.Account.Type) return nil, newError("Unknown account type: ", u.Account.Type)
} }
func (u *User) GetSettings() UserSettings {
settings := UserSettings{}
switch u.Level {
case 0:
settings.PayloadTimeout = time.Second * 30
case 1:
settings.PayloadTimeout = time.Minute * 2
default:
settings.PayloadTimeout = time.Minute * 5
}
return settings
}
type UserSettings struct {
PayloadTimeout time.Duration
}

View File

@ -6,7 +6,7 @@ import (
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
. "v2ray.com/core/common/retry" . "v2ray.com/core/common/retry"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
var ( var (
@ -14,7 +14,7 @@ var (
) )
func TestNoRetry(t *testing.T) { func TestNoRetry(t *testing.T) {
assert := assert.On(t) assert := With(t)
startTime := time.Now().Unix() startTime := time.Now().Unix()
err := Timed(10, 100000).On(func() error { err := Timed(10, 100000).On(func() error {
@ -22,12 +22,12 @@ func TestNoRetry(t *testing.T) {
}) })
endTime := time.Now().Unix() endTime := time.Now().Unix()
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int64(endTime - startTime).AtLeast(0) assert(endTime-startTime, AtLeast, int64(0))
} }
func TestRetryOnce(t *testing.T) { func TestRetryOnce(t *testing.T) {
assert := assert.On(t) assert := With(t)
startTime := time.Now() startTime := time.Now()
called := 0 called := 0
@ -40,12 +40,12 @@ func TestRetryOnce(t *testing.T) {
}) })
duration := time.Since(startTime) duration := time.Since(startTime)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int64(int64(duration / time.Millisecond)).AtLeast(900) assert(int64(duration/time.Millisecond), AtLeast, int64(900))
} }
func TestRetryMultiple(t *testing.T) { func TestRetryMultiple(t *testing.T) {
assert := assert.On(t) assert := With(t)
startTime := time.Now() startTime := time.Now()
called := 0 called := 0
@ -58,12 +58,12 @@ func TestRetryMultiple(t *testing.T) {
}) })
duration := time.Since(startTime) duration := time.Since(startTime)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Int64(int64(duration / time.Millisecond)).AtLeast(4900) assert(int64(duration/time.Millisecond), AtLeast, int64(4900))
} }
func TestRetryExhausted(t *testing.T) { func TestRetryExhausted(t *testing.T) {
assert := assert.On(t) assert := With(t)
startTime := time.Now() startTime := time.Now()
called := 0 called := 0
@ -73,12 +73,12 @@ func TestRetryExhausted(t *testing.T) {
}) })
duration := time.Since(startTime) duration := time.Since(startTime)
assert.Error(errors.Cause(err)).Equals(ErrRetryFailed) assert(errors.Cause(err), Equals, ErrRetryFailed)
assert.Int64(int64(duration / time.Millisecond)).AtLeast(1900) assert(int64(duration/time.Millisecond), AtLeast, int64(1900))
} }
func TestExponentialBackoff(t *testing.T) { func TestExponentialBackoff(t *testing.T) {
assert := assert.On(t) assert := With(t)
startTime := time.Now() startTime := time.Now()
called := 0 called := 0
@ -88,6 +88,6 @@ func TestExponentialBackoff(t *testing.T) {
}) })
duration := time.Since(startTime) duration := time.Since(startTime)
assert.Error(errors.Cause(err)).Equals(ErrRetryFailed) assert(errors.Cause(err), Equals, ErrRetryFailed)
assert.Int64(int64(duration / time.Millisecond)).AtLeast(4000) assert(int64(duration/time.Millisecond), AtLeast, int64(4000))
} }

View File

@ -4,11 +4,11 @@ import (
"testing" "testing"
. "v2ray.com/core/common/serial" . "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestBytesToHex(t *testing.T) { func TestBytesToHex(t *testing.T) {
assert := assert.On(t) assert := With(t)
cases := []struct { cases := []struct {
input []byte input []byte
@ -21,15 +21,15 @@ func TestBytesToHex(t *testing.T) {
} }
for _, test := range cases { for _, test := range cases {
assert.String(test.output).Equals(BytesToHexString(test.input)) assert(test.output, Equals, BytesToHexString(test.input))
} }
} }
func TestInt64(t *testing.T) { func TestInt64(t *testing.T) {
assert := assert.On(t) assert := With(t)
x := int64(375134875348) x := int64(375134875348)
b := Int64ToBytes(x, []byte{}) b := Int64ToBytes(x, []byte{})
v := BytesToInt64(b) v := BytesToInt64(b)
assert.Int64(x).Equals(v) assert(x, Equals, v)
} }

View File

@ -6,15 +6,15 @@ import (
"v2ray.com/core/common" "v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
. "v2ray.com/core/common/serial" . "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestUint32(t *testing.T) { func TestUint32(t *testing.T) {
assert := assert.On(t) assert := With(t)
x := uint32(458634234) x := uint32(458634234)
s1 := Uint32ToBytes(x, []byte{}) s1 := Uint32ToBytes(x, []byte{})
s2 := buf.New() s2 := buf.New()
common.Must(s2.AppendSupplier(WriteUint32(x))) common.Must(s2.AppendSupplier(WriteUint32(x)))
assert.Bytes(s1).Equals(s2.Bytes()) assert(s1, Equals, s2.Bytes())
} }

View File

@ -4,13 +4,13 @@ import (
"testing" "testing"
. "v2ray.com/core/common/serial" . "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestGetInstance(t *testing.T) { func TestGetInstance(t *testing.T) {
assert := assert.On(t) assert := With(t)
p, err := GetInstance("") p, err := GetInstance("")
assert.Pointer(p).IsNil() assert(p, IsNil)
assert.Error(err).IsNotNil() assert(err, IsNotNil)
} }

View File

@ -6,11 +6,11 @@ import (
"testing" "testing"
. "v2ray.com/core/common/signal" . "v2ray.com/core/common/signal"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestErrorOrFinish2_Error(t *testing.T) { func TestErrorOrFinish2_Error(t *testing.T) {
assert := assert.On(t) assert := With(t)
c1 := make(chan error, 1) c1 := make(chan error, 1)
c2 := make(chan error, 2) c2 := make(chan error, 2)
@ -22,11 +22,11 @@ func TestErrorOrFinish2_Error(t *testing.T) {
c1 <- errors.New("test") c1 <- errors.New("test")
err := <-c err := <-c
assert.String(err.Error()).Equals("test") assert(err.Error(), Equals, "test")
} }
func TestErrorOrFinish2_Error2(t *testing.T) { func TestErrorOrFinish2_Error2(t *testing.T) {
assert := assert.On(t) assert := With(t)
c1 := make(chan error, 1) c1 := make(chan error, 1)
c2 := make(chan error, 2) c2 := make(chan error, 2)
@ -38,11 +38,11 @@ func TestErrorOrFinish2_Error2(t *testing.T) {
c2 <- errors.New("test") c2 <- errors.New("test")
err := <-c err := <-c
assert.String(err.Error()).Equals("test") assert(err.Error(), Equals, "test")
} }
func TestErrorOrFinish2_NoneError(t *testing.T) { func TestErrorOrFinish2_NoneError(t *testing.T) {
assert := assert.On(t) assert := With(t)
c1 := make(chan error, 1) c1 := make(chan error, 1)
c2 := make(chan error, 2) c2 := make(chan error, 2)
@ -61,11 +61,11 @@ func TestErrorOrFinish2_NoneError(t *testing.T) {
close(c2) close(c2)
err := <-c err := <-c
assert.Error(err).IsNil() assert(err, IsNil)
} }
func TestErrorOrFinish2_NoneError2(t *testing.T) { func TestErrorOrFinish2_NoneError2(t *testing.T) {
assert := assert.On(t) assert := With(t)
c1 := make(chan error, 1) c1 := make(chan error, 1)
c2 := make(chan error, 2) c2 := make(chan error, 2)
@ -84,5 +84,5 @@ func TestErrorOrFinish2_NoneError2(t *testing.T) {
close(c1) close(c1)
err := <-c err := <-c
assert.Error(err).IsNil() assert(err, IsNil)
} }

View File

@ -12,8 +12,6 @@ type ActivityUpdater interface {
type ActivityTimer struct { type ActivityTimer struct {
updated chan bool updated chan bool
timeout chan time.Duration timeout chan time.Duration
ctx context.Context
cancel context.CancelFunc
} }
func (t *ActivityTimer) Update() { func (t *ActivityTimer) Update() {
@ -27,7 +25,7 @@ func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
t.timeout <- timeout t.timeout <- timeout
} }
func (t *ActivityTimer) run() { func (t *ActivityTimer) run(ctx context.Context, cancel context.CancelFunc) {
ticker := time.NewTicker(<-t.timeout) ticker := time.NewTicker(<-t.timeout)
defer func() { defer func() {
ticker.Stop() ticker.Stop()
@ -36,32 +34,35 @@ func (t *ActivityTimer) run() {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
case <-t.ctx.Done(): case <-ctx.Done():
return return
case timeout := <-t.timeout: case timeout := <-t.timeout:
if timeout == 0 {
cancel()
return
}
ticker.Stop() ticker.Stop()
ticker = time.NewTicker(timeout) ticker = time.NewTicker(timeout)
continue
} }
select { select {
case <-t.updated: case <-t.updated:
// Updated keep waiting. // Updated keep waiting.
default: default:
t.cancel() cancel()
return return
} }
} }
} }
func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) { func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
ctx, cancel := context.WithCancel(ctx)
timer := &ActivityTimer{ timer := &ActivityTimer{
ctx: ctx,
cancel: cancel,
timeout: make(chan time.Duration, 1), timeout: make(chan time.Duration, 1),
updated: make(chan bool, 1), updated: make(chan bool, 1),
} }
timer.timeout <- timeout timer.timeout <- timeout
go timer.run() go timer.run(ctx, cancel)
return ctx, timer return timer
} }

View File

@ -7,26 +7,28 @@ import (
"time" "time"
. "v2ray.com/core/common/signal" . "v2ray.com/core/common/signal"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestActivityTimer(t *testing.T) { func TestActivityTimer(t *testing.T) {
assert := assert.On(t) assert := With(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5) ctx, cancel := context.WithCancel(context.Background())
timer := CancelAfterInactivity(ctx, cancel, time.Second*5)
time.Sleep(time.Second * 6) time.Sleep(time.Second * 6)
assert.Error(ctx.Err()).IsNotNil() assert(ctx.Err(), IsNotNil)
runtime.KeepAlive(timer) runtime.KeepAlive(timer)
} }
func TestActivityTimerUpdate(t *testing.T) { func TestActivityTimerUpdate(t *testing.T) {
assert := assert.On(t) assert := With(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10) ctx, cancel := context.WithCancel(context.Background())
timer := CancelAfterInactivity(ctx, cancel, time.Second*10)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
assert.Error(ctx.Err()).IsNil() assert(ctx.Err(), IsNil)
timer.SetTimeout(time.Second * 1) timer.SetTimeout(time.Second * 1)
time.Sleep(time.Second * 2) time.Sleep(time.Second * 2)
assert.Error(ctx.Err()).IsNotNil() assert(ctx.Err(), IsNotNil)
runtime.KeepAlive(timer) runtime.KeepAlive(timer)
} }

View File

@ -4,74 +4,74 @@ import (
"testing" "testing"
. "v2ray.com/core/common/uuid" . "v2ray.com/core/common/uuid"
"v2ray.com/core/testing/assert" . "v2ray.com/ext/assert"
) )
func TestParseBytes(t *testing.T) { func TestParseBytes(t *testing.T) {
assert := assert.On(t) assert := With(t)
str := "2418d087-648d-4990-86e8-19dca1d006d3" str := "2418d087-648d-4990-86e8-19dca1d006d3"
bytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3} bytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
uuid, err := ParseBytes(bytes) uuid, err := ParseBytes(bytes)
assert.Error(err).IsNil() assert(err, IsNil)
assert.String(uuid.String()).Equals(str) assert(uuid.String(), Equals, str)
_, err = ParseBytes([]byte{1, 3, 2, 4}) _, err = ParseBytes([]byte{1, 3, 2, 4})
assert.Error(err).IsNotNil() assert(err, IsNotNil)
} }
func TestParseString(t *testing.T) { func TestParseString(t *testing.T) {
assert := assert.On(t) assert := With(t)
str := "2418d087-648d-4990-86e8-19dca1d006d3" str := "2418d087-648d-4990-86e8-19dca1d006d3"
expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3} expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
uuid, err := ParseString(str) uuid, err := ParseString(str)
assert.Error(err).IsNil() assert(err, IsNil)
assert.Bytes(uuid.Bytes()).Equals(expectedBytes) assert(uuid.Bytes(), Equals, expectedBytes)
uuid, err = ParseString("2418d087") uuid, err = ParseString("2418d087")
assert.Error(err).IsNotNil() assert(err, IsNotNil)
uuid, err = ParseString("2418d087-648k-4990-86e8-19dca1d006d3") uuid, err = ParseString("2418d087-648k-4990-86e8-19dca1d006d3")
assert.Error(err).IsNotNil() assert(err, IsNotNil)
} }
func TestNewUUID(t *testing.T) { func TestNewUUID(t *testing.T) {
assert := assert.On(t) assert := With(t)
uuid := New() uuid := New()
uuid2, err := ParseString(uuid.String()) uuid2, err := ParseString(uuid.String())
assert.Error(err).IsNil() assert(err, IsNil)
assert.String(uuid.String()).Equals(uuid2.String()) assert(uuid.String(), Equals, uuid2.String())
assert.Bytes(uuid.Bytes()).Equals(uuid2.Bytes()) assert(uuid.Bytes(), Equals, uuid2.Bytes())
} }
func TestRandom(t *testing.T) { func TestRandom(t *testing.T) {
assert := assert.On(t) assert := With(t)
uuid := New() uuid := New()
uuid2 := New() uuid2 := New()
assert.String(uuid.String()).NotEquals(uuid2.String()) assert(uuid.String(), NotEquals, uuid2.String())
assert.Bytes(uuid.Bytes()).NotEquals(uuid2.Bytes()) assert(uuid.Bytes(), NotEquals, uuid2.Bytes())
} }
func TestEquals(t *testing.T) { func TestEquals(t *testing.T) {
assert := assert.On(t) assert := With(t)
var uuid *UUID = nil var uuid *UUID = nil
var uuid2 *UUID = nil var uuid2 *UUID = nil
assert.Bool(uuid.Equals(uuid2)).IsTrue() assert(uuid.Equals(uuid2), IsTrue)
assert.Bool(uuid.Equals(New())).IsFalse() assert(uuid.Equals(New()), IsFalse)
} }
func TestNext(t *testing.T) { func TestNext(t *testing.T) {
assert := assert.On(t) assert := With(t)
uuid := New() uuid := New()
uuid2 := uuid.Next() uuid2 := uuid.Next()
assert.Bool(uuid.Equals(uuid2)).IsFalse() assert(uuid.Equals(uuid2), IsFalse)
} }

View File

@ -18,9 +18,9 @@ import (
) )
var ( var (
version = "2.41" version = "3.1"
build = "Custom" build = "Custom"
codename = "One for all" codename = "die Commanderin"
intro = "An unified platform for anti-censorship." intro = "An unified platform for anti-censorship."
) )

View File

@ -2,10 +2,11 @@ package core
import ( import (
"io" "io"
"io/ioutil"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"v2ray.com/core/common"
) )
// ConfigLoader is an utility to load V2Ray config from external source. // ConfigLoader is an utility to load V2Ray config from external source.
@ -30,7 +31,10 @@ func LoadConfig(format ConfigFormat, input io.Reader) (*Config, error) {
func loadProtobufConfig(input io.Reader) (*Config, error) { func loadProtobufConfig(input io.Reader) (*Config, error) {
config := new(Config) config := new(Config)
data, _ := ioutil.ReadAll(input) data, err := buf.ReadAllToBytes(input)
if err != nil {
return nil, err
}
if err := proto.Unmarshal(data, config); err != nil { if err := proto.Unmarshal(data, config); err != nil {
return nil, err return nil, err
} }

View File

@ -1,5 +1,47 @@
// +build json
package main package main
import _ "v2ray.com/core/tools/conf" import (
"io"
"os"
"os/exec"
"v2ray.com/core"
"v2ray.com/core/common/platform"
)
func jsonToProto(input io.Reader) (*core.Config, error) {
v2ctl := platform.GetToolLocation("v2ctl")
_, err := os.Stat(v2ctl)
if err != nil {
return nil, err
}
cmd := exec.Command(v2ctl, "config")
cmd.Stdin = input
cmd.Stderr = os.Stderr
stdoutReader, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}
defer stdoutReader.Close()
if err := cmd.Start(); err != nil {
return nil, err
}
config, err := core.LoadConfig(core.ConfigFormat_Protobuf, stdoutReader)
cmd.Wait()
return config, err
}
func init() {
core.RegisterConfigLoader(core.ConfigFormat_JSON, func(input io.Reader) (*core.Config, error) {
config, err := jsonToProto(input)
if err != nil {
return nil, newError("failed to execute v2ctl to convert config file.").Base(err)
}
return config, nil
})
}

View File

@ -4,6 +4,7 @@ import (
// The following are necessary as they register handlers in their init functions. // The following are necessary as they register handlers in their init functions.
_ "v2ray.com/core/app/dispatcher/impl" _ "v2ray.com/core/app/dispatcher/impl"
_ "v2ray.com/core/app/dns/server" _ "v2ray.com/core/app/dns/server"
_ "v2ray.com/core/app/policy/manager"
_ "v2ray.com/core/app/proxyman/inbound" _ "v2ray.com/core/app/proxyman/inbound"
_ "v2ray.com/core/app/proxyman/outbound" _ "v2ray.com/core/app/proxyman/outbound"
_ "v2ray.com/core/app/router" _ "v2ray.com/core/app/router"

View File

@ -22,6 +22,7 @@ var (
version = flag.Bool("version", false, "Show current version of V2Ray.") version = flag.Bool("version", false, "Show current version of V2Ray.")
test = flag.Bool("test", false, "Test config file only, without launching V2Ray server.") test = flag.Bool("test", false, "Test config file only, without launching V2Ray server.")
format = flag.String("format", "json", "Format of input file.") format = flag.String("format", "json", "Format of input file.")
plugin = flag.Bool("plugin", false, "True to load plugins.")
) )
func init() { func init() {
@ -67,7 +68,7 @@ func startV2Ray() (core.Server, error) {
server, err := core.New(config) server, err := core.New(config)
if err != nil { if err != nil {
return nil, newError("failed to create initialize").Base(err) return nil, newError("failed to create server").Base(err)
} }
return server, nil return server, nil
@ -82,19 +83,27 @@ func main() {
return return
} }
if *plugin {
if err := core.LoadPlugins(); err != nil {
fmt.Println("Failed to load plugins:", err.Error())
os.Exit(-1)
}
}
server, err := startV2Ray() server, err := startV2Ray()
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
return os.Exit(-1)
} }
if *test { if *test {
fmt.Println("Configuration OK.") fmt.Println("Configuration OK.")
return os.Exit(0)
} }
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
fmt.Println("Failed to start", err) fmt.Println("Failed to start", err)
os.Exit(-1)
} }
osSignals := make(chan os.Signal, 1) osSignals := make(chan os.Signal, 1)

Some files were not shown because too many files have changed in this diff Show More