1
0
mirror of https://github.com/v2fly/v2ray-core.git synced 2024-12-22 10:08:15 -05:00

Merge pull request #2 from v2ray/master

Update
This commit is contained in:
sunshineplan 2017-12-02 16:58:06 +08:00 committed by GitHub
commit d1fa98b60b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
225 changed files with 5693 additions and 4239 deletions

46
.github/CODE_OF_CONDUCT.md vendored Normal file
View File

@ -0,0 +1,46 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at love@v2ray.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]
[homepage]: http://contributor-covenant.org
[version]: http://contributor-covenant.org/version/1/4/

View File

@ -1,44 +1,94 @@
提交 Issue 之前请先阅读 [Issue 指引](https://www.v2ray.com/zh_cn/chapter_01/issue.html),然后回答下面的问题,谢谢。
Please read the [instruction](https://www.v2ray.com/en/get_started/issue.html) and answer the following questions before submitting your issue. Thank you.
Please skip to the English section below if you don't write Chinese.
中文:
提交 Issue 之前请先阅读 [Issue 指引](https://www.v2ray.com/chapter_01/issue.html),然后回答下面的问题,谢谢。
除非特殊情况,请完整填写所有问题。不按模板发的 issue 将直接被关闭。
1) 你正在使用哪个版本的 V2Ray如果服务器和客户端使用了不同版本请注明
What version of V2Ray are you using (If you deploy different version on server and client, please explicitly point out)?
2) 你的使用场景是什么?比如使用 Chrome 通过 Socks/VMess 代理观看 YouTube 视频。
What's your scenario of using V2Ray? E.g., Watching YouTube videos in Chrome via Socks/VMess proxy.
3) 你看到的不正常的现象是什么?
What did you see?
3) 你看到的不正常的现象是什么请描述具体现象比如访问超时TLS 证书错误等)
4) 你期待看到的正确表现是怎样的?
What's your expectation?
5) 请附上你的配置文件(提交 Issue 前请隐藏服务器端IP地址
Please attach your configuration file (**Mask IP addresses before submit this issue**).
Server Configuration File服务器端配置文件:
5) 请附上你的配置(提交 Issue 前请隐藏服务器端IP地址
服务器端配置:
```javascript
// 在这里附上服务器端配置文件
// Please attach your server configuration file here.
```
Client Configuration File客户端配置文件):
客户端配置:
```javascript
// 在这里附上客户端配置文件
// Please attach your client configuration file here.
// 在这里附上客户端配置
```
6) 请附上出错时软件输出的日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
Please attach the log file, especially the bottom lines if the file is large. Log file is usually `/var/log/v2ray/error.log` on Linux.
6) 请附上出错时软件输出的错误日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
Server Log File服务器端日志:
服务器端错误日志:
```
// 在这里附上服务器端日志
// Please attach your server log here.
```
Client Log File客户端日志:
客户端错误日志:
```
// 在这里附上客户端日志
// Please attach your client log here.
```
7) 请附上访问日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
```
// 在这里附上服务器端日志
```
8) 其它相关的配置文件(如 Nginx和相关日志。
请预览一下你填的内容再提交。
如果你已经填完上面的问卷,请把下面的英文部份删除,再提交 Issue。
Please remove the Chinese section above.
English:
Please read the [instruction](https://www.v2ray.com/en/get_started/issue.html) and answer the following questions before submitting your issue. Thank you.
Please answer all the questions with enough information. All issues not following this template will be closed immediately.
1) What version of V2Ray are you using (If you deploy different version on server and client, please explicitly point out)?
2) What's your scenario of using V2Ray? E.g., Watching YouTube videos in Chrome via Socks/VMess proxy.
3) What did you see? (Please describe in detail, such as timeout, fake TLS certificate etc)
4) What's your expectation?
5) Please attach your configuration file (**Mask IP addresses before submit this issue**).
Server configuration:
```javascript
// Please attach your server configuration here.
```
Client configuration:
```javascript
// Please attach your client configuration here.
```
6) Please attach error logs, especially the bottom lines if the file is large. Error log file is usually at `/var/log/v2ray/error.log` on Linux.
Server error log:
```
// Please attach your server error log here.
```
Client error log:
```
// Please attach your client error log here.
```
7) Please attach access log. Access log is usually at '/var/log/v2ray/access.log' on Linux.
```
// Please attach your server access log here.
```
Please review your issue before submitting.
8) Other configurations (such as Nginx) and logs.

63
.github/SUPPORT.md vendored Normal file
View File

@ -0,0 +1,63 @@
# V2Ray 用户支持 (User Support)
**English reader please skip to the [English section](#way-to-get-support) below**
## 获得帮助信息的途径
您可以从以下渠道获取帮助:
1. 官方网站:[v2ray.com](https://www.v2ray.com)
1. Github[Issues](https://github.com/v2ray/v2ray-core/issues)
1. Telegram[主群](https://t.me/projectv2ray)
## Github Issue 规则
1. 请按模板填写 issue
1. 配置文件内容使用格式化代码段进行修饰(见下面的解释);
1. 在提交 issue 前尝试减化配置文件,比如删除不必要 inbound / outbound 模块;
1. 在提交 issue 前尝试确定问题所在,比如将 socks 代理换成 http 再次观察问题是否能重现;
1. 配置文件必须结构完整,即除了必要的隐私信息之外,配置文件可以直接拿来运行。
**不按模板填写的 issue 将直接被关闭**
## 格式化代码段
在配置文件上下加入 Markdown 特定的修饰符,如下:
\`\`\`javascript
{
// 配置文件内容
}
\`\`\`
## Way to Get Support
You may get help in the following ways:
1. Office Site: [v2ray.com](https://www.v2ray.com)
1. Github: [Issues](https://github.com/v2ray/v2ray-core/issues)
1. Telegram: [Main Group](https://t.me/projectv2ray)
## Github Issue Rules
1. Please fill in the issue template.
1. Decorate config file with Markdown formatter (See below).
1. Try to simplify config file before submitting the issue, such as removing unnecessary inbound / outbound blocks.
1. Try to determine the cause of the issue, for example, replacing socks inbound with http inbound to see if the issue still exists.
1. Config file must be structurally complete.
**Any issue not following the issue template will be closed immediately.**
## Code formatter
Add the following Markdown decorator to config file content:
\`\`\`javascript
{
// config file
}
\`\`\`

6
.gitmodules vendored
View File

@ -1,3 +1,9 @@
[submodule "vendor/h12.me/socks"]
path = vendor/h12.me/socks
url = https://github.com/h12w/socks
[submodule "vendor/github.com/shadowsocks/go-shadowsocks2"]
path = vendor/github.com/shadowsocks/go-shadowsocks2
url = https://github.com/shadowsocks/go-shadowsocks2
[submodule "vendor/github.com/Yawning/chacha20"]
path = vendor/github.com/Yawning/chacha20
url = https://github.com/Yawning/chacha20

View File

@ -1,7 +1,7 @@
sudo: required
language: go
go:
- 1.9
- 1.9.2
go_import_path: v2ray.com/core
git:
depth: 5

19
.vscode/tasks.json vendored
View File

@ -1,13 +1,18 @@
{
"version": "0.1.0",
"version": "2.0.0",
"command": "go",
"isShellCommand": true,
"showOutput": "always",
"type": "shell",
"presentation": {
"echo": true,
"reveal": "always",
"focus": false,
"panel": "shared"
},
"tasks": [
{
"taskName": "build",
"label": "build",
"args": ["v2ray.com/core/..."],
"isBuildCommand": true,
"group": "build",
"problemMatcher": {
"owner": "go",
"fileLocation": ["relative", "${workspaceRoot}"],
@ -20,9 +25,9 @@
}
},
{
"taskName": "test",
"label": "test",
"args": ["-p", "1", "v2ray.com/core/..."],
"isBuildCommand": false
"group": "test"
}
]
}

View File

@ -1,4 +1,4 @@
# Project V2Ray
# Project V
[![Build Status][1]][2] [![codecov.io][3]][4] [![Go Report][5]][6] [![GoDoc][7]][8] [![codebeat][9]][10]
@ -13,11 +13,21 @@
[9]: https://codebeat.co/badges/f2354ca8-3e24-463d-a2e3-159af73b2477 "Codebeat badge"
[10]: https://codebeat.co/projects/github-com-v2ray-v2ray-core-master "Codebeat"
V2Ray 是一个模块化的代理软件包,它的目标是提供常用的代理软件模块,简化网络代理软件的开发。
V 是一个模块化的代理软件包,它的目标是提供常用的代理软件模块,简化网络代理软件的开发。
[官方网站](https://www.v2ray.com/)
V2Ray provides building blocks for network proxy development. Read our [Wiki](https://www.v2ray.com/en/index.html) for more information.
V provides building blocks for network proxy development. Read our [Wiki](https://www.v2ray.com/en/index.html) for more information.
## License
[The MIT License (MIT)](https://raw.githubusercontent.com/v2ray/v2ray-core/master/LICENSE)
## Credits
This repo relies on the following third-party projects:
* In production:
* [miekg/dns](https://github.com/miekg/dns)
* [gorilla/websocket](https://github.com/gorilla/websocket)
* For testing only:
* [h12w/socks](https://github.com/h12w/socks)

View File

@ -39,7 +39,7 @@ func NewDefaultDispatcher(ctx context.Context, config *dispatcher.Config) (*Defa
return nil, newError("no space in context")
}
d := &DefaultDispatcher{}
space.OnInitialize(func() error {
space.On(app.SpaceInitializing, func(interface{}) error {
d.ohm = proxyman.OutboundHandlerManagerFromSpace(space)
if d.ohm == nil {
return newError("OutboundHandlerManager is not found in the space")

View File

@ -3,12 +3,14 @@ package impl_test
import (
"testing"
"v2ray.com/core/app/proxyman"
. "v2ray.com/core/app/dispatcher/impl"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestHTTPHeaders(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cases := []struct {
input string
@ -94,13 +96,13 @@ first_name=John&last_name=Doe&action=Submit`,
for _, test := range cases {
domain, err := SniffHTTP([]byte(test.input))
assert.String(domain).Equals(test.domain)
assert.Error(err).Equals(test.err)
assert(domain, Equals, test.domain)
assert(err, Equals, test.err)
}
}
func TestTLSHeaders(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cases := []struct {
input []byte
@ -180,7 +182,13 @@ func TestTLSHeaders(t *testing.T) {
for _, test := range cases {
domain, err := SniffTLS(test.input)
assert.String(domain).Equals(test.domain)
assert.Error(err).Equals(test.err)
assert(domain, Equals, test.domain)
assert(err, Equals, test.err)
}
}
func TestUnknownSniffer(t *testing.T) {
assert := With(t)
assert(func() { NewSniffer([]proxyman.KnownProtocols{proxyman.KnownProtocols(-1)}) }, Panics)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/miekg/dns"
"v2ray.com/core/app/dispatcher"
"v2ray.com/core/app/log"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/dice"
"v2ray.com/core/common/net"
@ -15,13 +16,16 @@ import (
)
const (
DefaultTTL = uint32(3600)
CleanupInterval = time.Second * 120
CleanupThreshold = 512
)
var (
pseudoDestination = net.UDPDestination(net.LocalHostIP, net.Port(53))
multiQuestionDNS = map[net.Address]bool{
net.IPAddress([]byte{8, 8, 8, 8}): true,
net.IPAddress([]byte{8, 8, 4, 4}): true,
net.IPAddress([]byte{9, 9, 9, 9}): true,
}
)
type ARecord struct {
@ -55,54 +59,52 @@ func NewUDPNameServer(address net.Destination, dispatcher dispatcher.Interface)
return s
}
// Private: Visible for testing.
func (v *UDPNameServer) Cleanup() {
func (s *UDPNameServer) Cleanup() {
expiredRequests := make([]uint16, 0, 16)
now := time.Now()
v.Lock()
for id, r := range v.requests {
s.Lock()
for id, r := range s.requests {
if r.expire.Before(now) {
expiredRequests = append(expiredRequests, id)
close(r.response)
}
}
for _, id := range expiredRequests {
delete(v.requests, id)
delete(s.requests, id)
}
v.Unlock()
expiredRequests = nil
s.Unlock()
}
// Private: Visible for testing.
func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
var id uint16
v.Lock()
if len(v.requests) > CleanupThreshold && v.nextCleanup.Before(time.Now()) {
v.nextCleanup = time.Now().Add(CleanupInterval)
go v.Cleanup()
s.Lock()
if len(s.requests) > CleanupThreshold && s.nextCleanup.Before(time.Now()) {
s.nextCleanup = time.Now().Add(CleanupInterval)
go s.Cleanup()
}
for {
id = dice.RollUint16()
if _, found := v.requests[id]; found {
if _, found := s.requests[id]; found {
continue
}
log.Trace(newError("add pending request id ", id).AtDebug())
v.requests[id] = &PendingRequest{
s.requests[id] = &PendingRequest{
expire: time.Now().Add(time.Second * 8),
response: response,
}
break
}
v.Unlock()
s.Unlock()
return id
}
// Private: Visible for testing.
func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
func (s *UDPNameServer) HandleResponse(payload *buf.Buffer) {
msg := new(dns.Msg)
err := msg.Unpack(payload.Bytes())
if err != nil {
if err == dns.ErrTruncated {
log.Trace(newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning())
} else if err != nil {
log.Trace(newError("failed to parse DNS response").Base(err).AtWarning())
return
}
@ -110,17 +112,17 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
IPs: make([]net.IP, 0, 16),
}
id := msg.Id
ttl := DefaultTTL
log.Trace(newError("handling response for id ", id, " content: ", msg.String()).AtDebug())
ttl := uint32(3600) // an hour
log.Trace(newError("handling response for id ", id, " content: ", msg).AtDebug())
v.Lock()
request, found := v.requests[id]
s.Lock()
request, found := s.requests[id]
if !found {
v.Unlock()
s.Unlock()
return
}
delete(v.requests, id)
v.Unlock()
delete(s.requests, id)
s.Unlock()
for _, rr := range msg.Answer {
switch rr := rr.(type) {
@ -142,8 +144,7 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
close(request.response)
}
func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
func (s *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
msg := new(dns.Msg)
msg.Id = id
msg.RecursionDesired = true
@ -153,34 +154,40 @@ func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}}
if multiQuestionDNS[s.address.Address] {
msg.Question = append(msg.Question, dns.Question{
Name: dns.Fqdn(domain),
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
})
}
buffer := buf.New()
buffer.AppendSupplier(func(b []byte) (int, error) {
common.Must(buffer.Reset(func(b []byte) (int, error) {
writtenBuffer, err := msg.PackBuffer(b)
return len(writtenBuffer), err
})
}))
return buffer
}
func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
func (s *UDPNameServer) QueryA(domain string) <-chan *ARecord {
response := make(chan *ARecord, 1)
id := v.AssignUnusedID(response)
id := s.AssignUnusedID(response)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
ctx, cancel := context.WithCancel(context.Background())
s.udpServer.Dispatch(ctx, s.address, s.BuildQueryA(domain, id), s.HandleResponse)
go func() {
for i := 0; i < 2; i++ {
time.Sleep(time.Second)
v.Lock()
_, found := v.requests[id]
v.Unlock()
if found {
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
} else {
s.Lock()
_, found := s.requests[id]
s.Unlock()
if !found {
break
}
s.udpServer.Dispatch(ctx, s.address, s.BuildQueryA(domain, id), s.HandleResponse)
}
cancel()
}()
@ -191,7 +198,7 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
type LocalNameServer struct {
}
func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
func (*LocalNameServer) QueryA(domain string) <-chan *ARecord {
response := make(chan *ARecord, 1)
go func() {
@ -205,7 +212,7 @@ func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
response <- &ARecord{
IPs: ips,
Expire: time.Now().Add(time.Second * time.Duration(DefaultTTL)),
Expire: time.Now().Add(time.Hour),
}
}()

View File

@ -21,11 +21,22 @@ const (
)
type DomainRecord struct {
A *ARecord
IP []net.IP
Expire time.Time
LastAccess time.Time
}
func (r *DomainRecord) Expired() bool {
return r.Expire.Before(time.Now())
}
func (r *DomainRecord) Inactive() bool {
now := time.Now()
return r.Expire.Before(now) || r.LastAccess.Add(time.Minute*5).Before(now)
}
type CacheServer struct {
sync.RWMutex
sync.Mutex
hosts map[string]net.IP
records map[string]*DomainRecord
servers []NameServer
@ -41,7 +52,7 @@ func NewCacheServer(ctx context.Context, config *dns.Config) (*CacheServer, erro
servers: make([]NameServer, len(config.NameServers)),
hosts: config.GetInternalHosts(),
}
space.OnInitialize(func() error {
space.On(app.SpaceInitializing, func(interface{}) error {
disp := dispatcher.FromSpace(space)
if disp == nil {
return newError("dispatcher is not found in the space")
@ -79,15 +90,33 @@ func (*CacheServer) Start() error {
func (*CacheServer) Close() {}
func (s *CacheServer) GetCached(domain string) []net.IP {
s.RLock()
defer s.RUnlock()
s.Lock()
defer s.Unlock()
if record, found := s.records[domain]; found && record.A.Expire.After(time.Now()) {
return record.A.IPs
if record, found := s.records[domain]; found && !record.Expired() {
record.LastAccess = time.Now()
return record.IP
}
return nil
}
func (s *CacheServer) tryCleanup() {
s.Lock()
defer s.Unlock()
if len(s.records) > 256 {
domains := make([]string, 0, 256)
for d, r := range s.records {
if r.Expired() {
domains = append(domains, d)
}
}
for _, d := range domains {
delete(s.records, d)
}
}
}
func (s *CacheServer) Get(domain string) []net.IP {
if ip, found := s.hosts[domain]; found {
return []net.IP{ip}
@ -99,6 +128,8 @@ func (s *CacheServer) Get(domain string) []net.IP {
return ips
}
s.tryCleanup()
for _, server := range s.servers {
response := server.QueryA(domain)
select {
@ -108,7 +139,9 @@ func (s *CacheServer) Get(domain string) []net.IP {
}
s.Lock()
s.records[domain] = &DomainRecord{
A: a,
IP: a.IPs,
Expire: a.Expire,
LastAccess: time.Now(),
}
s.Unlock()
log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())

View File

@ -0,0 +1,105 @@
package server_test
import (
"context"
"testing"
"v2ray.com/core/app"
"v2ray.com/core/app/dispatcher"
_ "v2ray.com/core/app/dispatcher/impl"
. "v2ray.com/core/app/dns"
_ "v2ray.com/core/app/dns/server"
"v2ray.com/core/app/policy"
_ "v2ray.com/core/app/policy/manager"
"v2ray.com/core/app/proxyman"
_ "v2ray.com/core/app/proxyman/outbound"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/proxy/freedom"
"v2ray.com/core/testing/servers/udp"
. "v2ray.com/ext/assert"
"github.com/miekg/dns"
)
type staticHandler struct {
}
func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ans := new(dns.Msg)
ans.Id = r.Id
for _, q := range r.Question {
if q.Name == "google.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
ans.Answer = append(ans.Answer, rr)
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
ans.Answer = append(ans.Answer, rr)
}
}
w.WriteMsg(ans)
}
func TestUDPServer(t *testing.T) {
assert := With(t)
port := udp.PickPort()
dnsServer := dns.Server{
Addr: "127.0.0.1:" + port.String(),
Net: "udp",
Handler: &staticHandler{},
UDPSize: 1200,
}
go dnsServer.ListenAndServe()
config := &Config{
NameServers: []*net.Endpoint{
{
Network: net.Network_UDP,
Address: &net.IPOrDomain{
Address: &net.IPOrDomain_Ip{
Ip: []byte{127, 0, 0, 1},
},
},
Port: uint32(port),
},
},
}
ctx := context.Background()
space := app.NewSpace()
ctx = app.ContextWithSpace(ctx, space)
common.Must(app.AddApplicationToSpace(ctx, config))
common.Must(app.AddApplicationToSpace(ctx, &dispatcher.Config{}))
common.Must(app.AddApplicationToSpace(ctx, &proxyman.OutboundConfig{}))
common.Must(app.AddApplicationToSpace(ctx, &policy.Config{}))
om := proxyman.OutboundHandlerManagerFromSpace(space)
om.AddHandler(ctx, &proxyman.OutboundHandlerConfig{
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
common.Must(space.Initialize())
common.Must(space.Start())
server := FromSpace(space)
assert(server, IsNotNil)
ips := server.Get("google.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
ips = server.Get("facebook.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{9, 9, 9, 9})
dnsServer.Shutdown()
ips = server.Get("google.com")
assert(len(ips), Equals, 1)
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
}

View File

@ -4,11 +4,11 @@ import (
"testing"
. "v2ray.com/core/app/log/internal"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestAccessLog(t *testing.T) {
assert := assert.On(t)
assert := With(t)
entry := &AccessLog{
From: "test_from",
@ -18,8 +18,8 @@ func TestAccessLog(t *testing.T) {
}
entryStr := entry.String()
assert.String(entryStr).Contains("test_from")
assert.String(entryStr).Contains("test_to")
assert.String(entryStr).Contains("test_reason")
assert.String(entryStr).Contains("Accepted")
assert(entryStr, HasSubstring, "test_from")
assert(entryStr, HasSubstring, "test_to")
assert(entryStr, HasSubstring, "test_reason")
assert(entryStr, HasSubstring, "Accepted")
}

26
app/policy/config.go Normal file
View File

@ -0,0 +1,26 @@
package policy
import (
"time"
)
func (s *Second) Duration() time.Duration {
return time.Second * time.Duration(s.Value)
}
func (p *Policy) OverrideWith(another *Policy) {
if another.Timeout != nil {
if another.Timeout.Handshake != nil {
p.Timeout.Handshake = another.Timeout.Handshake
}
if another.Timeout.ConnectionIdle != nil {
p.Timeout.ConnectionIdle = another.Timeout.ConnectionIdle
}
if another.Timeout.UplinkOnly != nil {
p.Timeout.UplinkOnly = another.Timeout.UplinkOnly
}
if another.Timeout.DownlinkOnly != nil {
p.Timeout.DownlinkOnly = another.Timeout.DownlinkOnly
}
}
}

140
app/policy/config.pb.go Normal file
View File

@ -0,0 +1,140 @@
package policy
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Second struct {
Value uint32 `protobuf:"varint,1,opt,name=value" json:"value,omitempty"`
}
func (m *Second) Reset() { *m = Second{} }
func (m *Second) String() string { return proto.CompactTextString(m) }
func (*Second) ProtoMessage() {}
func (*Second) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Second) GetValue() uint32 {
if m != nil {
return m.Value
}
return 0
}
type Policy struct {
Timeout *Policy_Timeout `protobuf:"bytes,1,opt,name=timeout" json:"timeout,omitempty"`
}
func (m *Policy) Reset() { *m = Policy{} }
func (m *Policy) String() string { return proto.CompactTextString(m) }
func (*Policy) ProtoMessage() {}
func (*Policy) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
func (m *Policy) GetTimeout() *Policy_Timeout {
if m != nil {
return m.Timeout
}
return nil
}
// Timeout is a message for timeout settings in various stages, in seconds.
type Policy_Timeout struct {
Handshake *Second `protobuf:"bytes,1,opt,name=handshake" json:"handshake,omitempty"`
ConnectionIdle *Second `protobuf:"bytes,2,opt,name=connection_idle,json=connectionIdle" json:"connection_idle,omitempty"`
UplinkOnly *Second `protobuf:"bytes,3,opt,name=uplink_only,json=uplinkOnly" json:"uplink_only,omitempty"`
DownlinkOnly *Second `protobuf:"bytes,4,opt,name=downlink_only,json=downlinkOnly" json:"downlink_only,omitempty"`
}
func (m *Policy_Timeout) Reset() { *m = Policy_Timeout{} }
func (m *Policy_Timeout) String() string { return proto.CompactTextString(m) }
func (*Policy_Timeout) ProtoMessage() {}
func (*Policy_Timeout) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1, 0} }
func (m *Policy_Timeout) GetHandshake() *Second {
if m != nil {
return m.Handshake
}
return nil
}
func (m *Policy_Timeout) GetConnectionIdle() *Second {
if m != nil {
return m.ConnectionIdle
}
return nil
}
func (m *Policy_Timeout) GetUplinkOnly() *Second {
if m != nil {
return m.UplinkOnly
}
return nil
}
func (m *Policy_Timeout) GetDownlinkOnly() *Second {
if m != nil {
return m.DownlinkOnly
}
return nil
}
type Config struct {
Level map[uint32]*Policy `protobuf:"bytes,1,rep,name=level" json:"level,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
}
func (m *Config) Reset() { *m = Config{} }
func (m *Config) String() string { return proto.CompactTextString(m) }
func (*Config) ProtoMessage() {}
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *Config) GetLevel() map[uint32]*Policy {
if m != nil {
return m.Level
}
return nil
}
func init() {
proto.RegisterType((*Second)(nil), "v2ray.core.app.policy.Second")
proto.RegisterType((*Policy)(nil), "v2ray.core.app.policy.Policy")
proto.RegisterType((*Policy_Timeout)(nil), "v2ray.core.app.policy.Policy.Timeout")
proto.RegisterType((*Config)(nil), "v2ray.core.app.policy.Config")
}
func init() { proto.RegisterFile("v2ray.com/core/app/policy/config.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 349 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xc1, 0x4a, 0xeb, 0x40,
0x14, 0x86, 0x49, 0x7a, 0x9b, 0x72, 0x4f, 0x6f, 0xaf, 0x32, 0x58, 0x88, 0x05, 0xa5, 0x14, 0x94,
0xae, 0x26, 0x90, 0x6e, 0x44, 0xb1, 0x62, 0x45, 0x41, 0x10, 0x2c, 0x51, 0x14, 0xdc, 0x94, 0x71,
0x32, 0xda, 0xd0, 0xe9, 0x9c, 0x21, 0xa6, 0x95, 0xbc, 0x86, 0x6f, 0xe0, 0xd6, 0x87, 0xf2, 0x59,
0x24, 0x99, 0x84, 0x6c, 0x5a, 0xe9, 0x6e, 0x72, 0xf8, 0xfe, 0x8f, 0x43, 0xfe, 0x03, 0x87, 0x4b,
0x3f, 0x66, 0x29, 0xe5, 0x38, 0xf7, 0x38, 0xc6, 0xc2, 0x63, 0x5a, 0x7b, 0x1a, 0x65, 0xc4, 0x53,
0x8f, 0xa3, 0x7a, 0x89, 0x5e, 0xa9, 0x8e, 0x31, 0x41, 0xd2, 0x2e, 0xb9, 0x58, 0x50, 0xa6, 0x35,
0x35, 0x4c, 0x6f, 0x1f, 0x9c, 0x3b, 0xc1, 0x51, 0x85, 0x64, 0x07, 0xea, 0x4b, 0x26, 0x17, 0xc2,
0xb5, 0xba, 0x56, 0xbf, 0x15, 0x98, 0x8f, 0xde, 0xb7, 0x0d, 0xce, 0x38, 0x47, 0xc9, 0x19, 0x34,
0x92, 0x68, 0x2e, 0x70, 0x91, 0xe4, 0x48, 0xd3, 0x3f, 0xa0, 0x2b, 0x9d, 0xd4, 0xf0, 0xf4, 0xde,
0xc0, 0x41, 0x99, 0xea, 0x7c, 0xd8, 0xd0, 0x28, 0x86, 0xe4, 0x04, 0xfe, 0x4e, 0x99, 0x0a, 0xdf,
0xa6, 0x6c, 0x26, 0x0a, 0xdd, 0xde, 0x1a, 0x9d, 0xd9, 0x2f, 0xa8, 0x78, 0x72, 0x05, 0x5b, 0x1c,
0x95, 0x12, 0x3c, 0x89, 0x50, 0x4d, 0xa2, 0x50, 0x0a, 0xd7, 0xde, 0x44, 0xf1, 0xbf, 0x4a, 0x5d,
0x87, 0x52, 0x90, 0x21, 0x34, 0x17, 0x5a, 0x46, 0x6a, 0x36, 0x41, 0x25, 0x53, 0xb7, 0xb6, 0x89,
0x03, 0x4c, 0xe2, 0x56, 0xc9, 0x94, 0x8c, 0xa0, 0x15, 0xe2, 0xbb, 0xaa, 0x0c, 0x7f, 0x36, 0x31,
0xfc, 0x2b, 0x33, 0x99, 0xa3, 0xf7, 0x69, 0x81, 0x73, 0x91, 0x17, 0x45, 0x86, 0x50, 0x97, 0x62,
0x29, 0xa4, 0x6b, 0x75, 0x6b, 0xfd, 0xa6, 0xdf, 0x5f, 0xa3, 0x31, 0x34, 0xbd, 0xc9, 0xd0, 0x4b,
0x95, 0xc4, 0x69, 0x60, 0x62, 0x9d, 0x47, 0x80, 0x6a, 0x48, 0xb6, 0xa1, 0x36, 0x13, 0x69, 0xd1,
0x66, 0xf6, 0x24, 0x83, 0xb2, 0xe1, 0xdf, 0x7f, 0x96, 0xa9, 0xaf, 0x38, 0x80, 0x63, 0xfb, 0xc8,
0x1a, 0x9d, 0xc2, 0x2e, 0xc7, 0xf9, 0x6a, 0x7c, 0x6c, 0x3d, 0x39, 0xe6, 0xf5, 0x65, 0xb7, 0x1f,
0xfc, 0x80, 0x65, 0x0b, 0xc6, 0x82, 0x9e, 0x6b, 0x5d, 0x98, 0x9e, 0x9d, 0xfc, 0x02, 0x07, 0x3f,
0x01, 0x00, 0x00, 0xff, 0xff, 0xcf, 0x25, 0x25, 0xc2, 0xab, 0x02, 0x00, 0x00,
}

27
app/policy/config.proto Normal file
View File

@ -0,0 +1,27 @@
syntax = "proto3";
package v2ray.core.app.policy;
option csharp_namespace = "V2Ray.Core.App.Policy";
option go_package = "policy";
option java_package = "com.v2ray.core.app.policy";
option java_multiple_files = true;
message Second {
uint32 value = 1;
}
message Policy {
// Timeout is a message for timeout settings in various stages, in seconds.
message Timeout {
Second handshake = 1;
Second connection_idle = 2;
Second uplink_only = 3;
Second downlink_only = 4;
}
Timeout timeout = 1;
}
message Config {
map<uint32, Policy> level = 1;
}

View File

@ -0,0 +1,66 @@
package manager
import (
"context"
"v2ray.com/core/app/policy"
"v2ray.com/core/common"
)
type Instance struct {
levels map[uint32]*policy.Policy
}
func New(ctx context.Context, config *policy.Config) (*Instance, error) {
levels := config.Level
if levels == nil {
levels = make(map[uint32]*policy.Policy)
}
for _, p := range levels {
g := global()
g.OverrideWith(p)
*p = g
}
return &Instance{
levels: levels,
}, nil
}
func global() policy.Policy {
return policy.Policy{
Timeout: &policy.Policy_Timeout{
Handshake: &policy.Second{Value: 4},
ConnectionIdle: &policy.Second{Value: 300},
UplinkOnly: &policy.Second{Value: 5},
DownlinkOnly: &policy.Second{Value: 30},
},
}
}
// GetPolicy implements policy.Manager.
func (m *Instance) GetPolicy(level uint32) policy.Policy {
if p, ok := m.levels[level]; ok {
return *p
}
return global()
}
// Start implements app.Application.Start().
func (m *Instance) Start() error {
return nil
}
// Close implements app.Application.Close().
func (m *Instance) Close() {
}
// Interface implement app.Application.Interface().
func (m *Instance) Interface() interface{} {
return (*policy.Manager)(nil)
}
func init() {
common.Must(common.RegisterConfig((*policy.Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return New(ctx, config.(*policy.Config))
}))
}

20
app/policy/policy.go Normal file
View File

@ -0,0 +1,20 @@
package policy
import (
"v2ray.com/core/app"
)
// Manager is an utility to manage policy per user level.
type Manager interface {
// GetPolicy returns the Policy for the given user level.
GetPolicy(level uint32) Policy
}
// FromSpace returns the policy.Manager in a space.
func FromSpace(space app.Space) Manager {
app := space.GetApplication((*Manager)(nil))
if app == nil {
return nil
}
return app.(Manager)
}

View File

@ -172,6 +172,11 @@ func (*udpConn) SetWriteDeadline(time.Time) error {
return nil
}
type connId struct {
src net.Destination
dest net.Destination
}
type udpWorker struct {
sync.RWMutex
@ -185,39 +190,43 @@ type udpWorker struct {
ctx context.Context
cancel context.CancelFunc
activeConn map[net.Destination]*udpConn
activeConn map[connId]*udpConn
}
func (w *udpWorker) getConnection(src net.Destination) (*udpConn, bool) {
func (w *udpWorker) getConnection(id connId) (*udpConn, bool) {
w.Lock()
defer w.Unlock()
if conn, found := w.activeConn[src]; found {
if conn, found := w.activeConn[id]; found {
return conn, true
}
conn := &udpConn{
input: make(chan *buf.Buffer, 32),
output: func(b []byte) (int, error) {
return w.hub.WriteTo(b, src)
return w.hub.WriteTo(b, id.src)
},
remote: &net.UDPAddr{
IP: src.Address.IP(),
Port: int(src.Port),
IP: id.src.Address.IP(),
Port: int(id.src.Port),
},
local: &net.UDPAddr{
IP: w.address.IP(),
Port: int(w.port),
},
}
w.activeConn[src] = conn
w.activeConn[id] = conn
conn.updateActivity()
return conn, false
}
func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
conn, existing := w.getConnection(source)
id := connId{
src: source,
dest: originalDest,
}
conn, existing := w.getConnection(id)
select {
case conn.input <- b:
default:
@ -240,20 +249,20 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
log.Trace(newError("connection ends").Base(err))
}
w.removeConn(source)
w.removeConn(id)
cancel()
}()
}
}
func (w *udpWorker) removeConn(src net.Destination) {
func (w *udpWorker) removeConn(id connId) {
w.Lock()
delete(w.activeConn, src)
delete(w.activeConn, id)
w.Unlock()
}
func (w *udpWorker) Start() error {
w.activeConn = make(map[net.Destination]*udpConn)
w.activeConn = make(map[connId]*udpConn, 16)
ctx, cancel := context.WithCancel(context.Background())
w.ctx = ctx
w.cancel = cancel

View File

@ -1,8 +1,10 @@
package mux
import (
"v2ray.com/core/common/bitmask"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/common/serial"
)
@ -15,24 +17,10 @@ const (
SessionStatusKeepAlive SessionStatus = 0x04
)
type Option byte
const (
OptionData Option = 0x01
OptionData bitmask.Byte = 0x01
)
func (o Option) Has(x Option) bool {
return (o & x) == x
}
func (o *Option) Add(x Option) {
*o = (*o | x)
}
func (o *Option) Clear(x Option) {
*o = (*o & (^x))
}
type TargetNetwork byte
const (
@ -40,14 +28,6 @@ const (
TargetNetworkUDP TargetNetwork = 0x02
)
type AddressType byte
const (
AddressTypeIPv4 AddressType = 0x01
AddressTypeDomain AddressType = 0x02
AddressTypeIPv6 AddressType = 0x03
)
/*
Frame format
2 bytes - length
@ -62,10 +42,10 @@ n bytes - address
*/
type FrameMetadata struct {
SessionID uint16
SessionStatus SessionStatus
Target net.Destination
Option Option
SessionID uint16
Option bitmask.Byte
SessionStatus SessionStatus
}
func (f FrameMetadata) AsSupplier() buf.Supplier {
@ -92,17 +72,21 @@ func (f FrameMetadata) AsSupplier() buf.Supplier {
addr := f.Target.Address
switch addr.Family() {
case net.AddressFamilyIPv4:
b = append(b, byte(AddressTypeIPv4))
b = append(b, byte(protocol.AddressTypeIPv4))
b = append(b, addr.IP()...)
length += 5
case net.AddressFamilyIPv6:
b = append(b, byte(AddressTypeIPv6))
b = append(b, byte(protocol.AddressTypeIPv6))
b = append(b, addr.IP()...)
length += 17
case net.AddressFamilyDomain:
nDomain := len(addr.Domain())
b = append(b, byte(AddressTypeDomain), byte(nDomain))
b = append(b, addr.Domain()...)
domain := addr.Domain()
if protocol.IsDomainTooLong(domain) {
return 0, newError("domain name too long: ", domain)
}
nDomain := len(domain)
b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain))
b = append(b, domain...)
length += nDomain + 2
}
}
@ -120,7 +104,7 @@ func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
f := &FrameMetadata{
SessionID: serial.BytesToUint16(b[:2]),
SessionStatus: SessionStatus(b[2]),
Option: Option(b[3]),
Option: bitmask.Byte(b[3]),
}
b = b[4:]
@ -128,18 +112,18 @@ func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
if f.SessionStatus == SessionStatusNew {
network := TargetNetwork(b[0])
port := net.PortFromBytes(b[1:3])
addrType := AddressType(b[3])
addrType := protocol.AddressType(b[3])
b = b[4:]
var addr net.Address
switch addrType {
case AddressTypeIPv4:
case protocol.AddressTypeIPv4:
addr = net.IPAddress(b[0:4])
b = b[4:]
case AddressTypeIPv6:
case protocol.AddressTypeIPv6:
addr = net.IPAddress(b[0:16])
b = b[16:]
case AddressTypeDomain:
case protocol.AddressTypeDomain:
nDomain := int(b[0])
addr = net.DomainAddress(string(b[1 : 1+nDomain]))
b = b[nDomain+1:]

View File

@ -90,7 +90,19 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
ctx, cancel := context.WithCancel(context.Background())
ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(muxCoolAddress, muxCoolPort))
pipe := ray.NewRay(ctx)
go p.Process(ctx, pipe, dialer)
go func() {
if err := p.Process(ctx, pipe, dialer); err != nil {
cancel()
traceErr := errors.New("failed to handler mux client connection").Base(err)
if err != io.EOF && err != context.Canceled {
traceErr = traceErr.AtWarning()
}
log.Trace(traceErr)
}
}()
c := &Client{
sessionManager: NewSessionManager(),
inboundRay: pipe,
@ -104,6 +116,7 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
return c, nil
}
// Closed returns true if this Client is closed.
func (m *Client) Closed() bool {
select {
case <-m.ctx.Done():
@ -148,7 +161,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
log.Trace(newError("dispatching request to ", dest))
data, _ := s.input.ReadTimeout(time.Millisecond * 500)
if err := writer.Write(data); err != nil {
if err := writer.WriteMultiBuffer(data); err != nil {
log.Trace(newError("failed to write first payload").Base(err))
return
}
@ -179,26 +192,25 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
return true
}
func drain(reader io.Reader) error {
buf.Copy(NewStreamReader(reader), buf.Discard)
return nil
func drain(reader *buf.BufferedReader) error {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return drain(reader)
}
return nil
}
func (m *Client) handleStatusNew(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return drain(reader)
}
return nil
}
func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
@ -209,7 +221,7 @@ func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
return drain(reader)
}
func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := m.sessionManager.Get(meta.SessionID); found {
s.Close()
}
@ -222,11 +234,10 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
func (m *Client) fetchOutput() {
defer m.cancel()
reader := buf.ToBytesReader(m.inboundRay.InboundOutput())
metaReader := NewMetadataReader(reader)
reader := buf.NewBufferedReader(m.inboundRay.InboundOutput())
for {
meta, err := metaReader.Read()
meta, err := ReadMetadata(reader)
if err != nil {
if errors.Cause(err) != io.EOF {
log.Trace(newError("failed to read metadata").Base(err))
@ -263,7 +274,7 @@ type Server struct {
func NewServer(ctx context.Context) *Server {
s := &Server{}
space := app.SpaceFromContext(ctx)
space.OnInitialize(func() error {
space.On(app.SpaceInitializing, func(interface{}) error {
d := dispatcher.FromSpace(space)
if d == nil {
return newError("no dispatcher in space")
@ -304,14 +315,14 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
s.Close()
}
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader io.Reader) error {
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return drain(reader)
}
return nil
}
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader io.Reader) error {
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
log.Trace(newError("received request for ", meta.Target))
inboundRay, err := w.dispatcher.Dispatch(ctx, meta.Target)
if err != nil {
@ -338,7 +349,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
return nil
}
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
@ -348,7 +359,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) e
return drain(reader)
}
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := w.sessionManager.Get(meta.SessionID); found {
s.Close()
}
@ -358,9 +369,8 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) er
return nil
}
func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error {
metaReader := NewMetadataReader(reader)
meta, err := metaReader.Read()
func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error {
meta, err := ReadMetadata(reader)
if err != nil {
return newError("failed to read metadata").Base(err)
}
@ -386,7 +396,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error
func (w *ServerWorker) run(ctx context.Context) {
input := w.outboundRay.OutboundInput()
reader := buf.ToBytesReader(input)
reader := buf.NewBufferedReader(input)
defer w.sessionManager.Close()

View File

@ -9,14 +9,14 @@ import (
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray"
. "v2ray.com/ext/assert"
)
func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
mb := buf.NewMultiBuffer()
var mb buf.MultiBuffer
for {
b, err := reader.Read()
b, err := reader.ReadMultiBuffer()
if err == io.EOF {
break
}
@ -29,7 +29,7 @@ func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
}
func TestReaderWriter(t *testing.T) {
assert := assert.On(t)
assert := With(t)
stream := ray.NewStream(context.Background())
@ -45,98 +45,98 @@ func TestReaderWriter(t *testing.T) {
writePayload := func(writer *Writer, payload ...byte) error {
b := buf.New()
b.Append(payload)
return writer.Write(buf.NewMultiBufferValue(b))
return writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
}
assert.Error(writePayload(writer, 'a', 'b', 'c', 'd')).IsNil()
assert.Error(writePayload(writer2)).IsNil()
assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil)
assert(writePayload(writer2), IsNil)
assert.Error(writePayload(writer, 'e', 'f', 'g', 'h')).IsNil()
assert.Error(writePayload(writer3, 'x')).IsNil()
assert(writePayload(writer, 'e', 'f', 'g', 'h'), IsNil)
assert(writePayload(writer3, 'x'), IsNil)
writer.Close()
writer3.Close()
assert.Error(writePayload(writer2, 'y')).IsNil()
assert(writePayload(writer2, 'y'), IsNil)
writer2.Close()
bytesReader := buf.ToBytesReader(stream)
metaReader := NewMetadataReader(bytesReader)
bytesReader := buf.NewBufferedReader(stream)
streamReader := NewStreamReader(bytesReader)
meta, err := metaReader.Read()
assert.Error(err).IsNil()
assert.Uint16(meta.SessionID).Equals(1)
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
assert.Destination(meta.Target).Equals(dest)
assert.Byte(byte(meta.Option)).Equals(byte(OptionData))
meta, err := ReadMetadata(bytesReader)
assert(err, IsNil)
assert(meta.SessionID, Equals, uint16(1))
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
assert(meta.Target, Equals, dest)
assert(byte(meta.Option), Equals, byte(OptionData))
data, err := readAll(streamReader)
assert.Error(err).IsNil()
assert.Int(len(data)).Equals(1)
assert.String(data[0].String()).Equals("abcd")
assert(err, IsNil)
assert(len(data), Equals, 1)
assert(data[0].String(), Equals, "abcd")
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
assert.Uint16(meta.SessionID).Equals(2)
assert.Byte(byte(meta.Option)).Equals(0)
assert.Destination(meta.Target).Equals(dest2)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
assert(meta.SessionID, Equals, uint16(2))
assert(byte(meta.Option), Equals, byte(0))
assert(meta.Target, Equals, dest2)
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep))
assert.Uint16(meta.SessionID).Equals(1)
assert.Byte(byte(meta.Option)).Equals(1)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
assert(meta.SessionID, Equals, uint16(1))
assert(byte(meta.Option), Equals, byte(1))
data, err = readAll(streamReader)
assert.Error(err).IsNil()
assert.Int(len(data)).Equals(1)
assert.String(data[0].String()).Equals("efgh")
assert(err, IsNil)
assert(len(data), Equals, 1)
assert(data[0].String(), Equals, "efgh")
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
assert.Uint16(meta.SessionID).Equals(3)
assert.Byte(byte(meta.Option)).Equals(1)
assert.Destination(meta.Target).Equals(dest3)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
assert(meta.SessionID, Equals, uint16(3))
assert(byte(meta.Option), Equals, byte(1))
assert(meta.Target, Equals, dest3)
data, err = readAll(streamReader)
assert.Error(err).IsNil()
assert.Int(len(data)).Equals(1)
assert.String(data[0].String()).Equals("x")
assert(err, IsNil)
assert(len(data), Equals, 1)
assert(data[0].String(), Equals, "x")
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
assert.Uint16(meta.SessionID).Equals(1)
assert.Byte(byte(meta.Option)).Equals(0)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
assert(meta.SessionID, Equals, uint16(1))
assert(byte(meta.Option), Equals, byte(0))
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
assert.Uint16(meta.SessionID).Equals(3)
assert.Byte(byte(meta.Option)).Equals(0)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
assert(meta.SessionID, Equals, uint16(3))
assert(byte(meta.Option), Equals, byte(0))
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep))
assert.Uint16(meta.SessionID).Equals(2)
assert.Byte(byte(meta.Option)).Equals(1)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
assert(meta.SessionID, Equals, uint16(2))
assert(byte(meta.Option), Equals, byte(1))
data, err = readAll(streamReader)
assert.Error(err).IsNil()
assert.Int(len(data)).Equals(1)
assert.String(data[0].String()).Equals("y")
assert(err, IsNil)
assert(len(data), Equals, 1)
assert(data[0].String(), Equals, "y")
meta, err = metaReader.Read()
assert.Error(err).IsNil()
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
assert.Uint16(meta.SessionID).Equals(2)
assert.Byte(byte(meta.Option)).Equals(0)
meta, err = ReadMetadata(bytesReader)
assert(err, IsNil)
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
assert(meta.SessionID, Equals, uint16(2))
assert(byte(meta.Option), Equals, byte(0))
stream.Close()
meta, err = metaReader.Read()
assert.Error(err).IsNotNil()
meta, err = ReadMetadata(bytesReader)
assert(err, IsNotNil)
assert(meta, IsNil)
}

View File

@ -7,20 +7,9 @@ import (
"v2ray.com/core/common/serial"
)
type MetadataReader struct {
reader io.Reader
buffer []byte
}
func NewMetadataReader(reader io.Reader) *MetadataReader {
return &MetadataReader{
reader: reader,
buffer: make([]byte, 1024),
}
}
func (r *MetadataReader) Read() (*FrameMetadata, error) {
metaLen, err := serial.ReadUint16(r.reader)
// ReadMetadata reads FrameMetadata from the given reader.
func ReadMetadata(reader io.Reader) (*FrameMetadata, error) {
metaLen, err := serial.ReadUint16(reader)
if err != nil {
return nil, err
}
@ -28,17 +17,22 @@ func (r *MetadataReader) Read() (*FrameMetadata, error) {
return nil, newError("invalid metalen ", metaLen).AtWarning()
}
if _, err := io.ReadFull(r.reader, r.buffer[:metaLen]); err != nil {
b := buf.New()
defer b.Release()
if err := b.Reset(buf.ReadFullFrom(reader, int(metaLen))); err != nil {
return nil, err
}
return ReadFrameFrom(r.buffer)
return ReadFrameFrom(b.Bytes())
}
// PacketReader is an io.Reader that reads whole chunk of Mux frames every time.
type PacketReader struct {
reader io.Reader
eof bool
}
// NewPacketReader creates a new PacketReader.
func NewPacketReader(reader io.Reader) *PacketReader {
return &PacketReader{
reader: reader,
@ -46,7 +40,8 @@ func NewPacketReader(reader io.Reader) *PacketReader {
}
}
func (r *PacketReader) Read() (buf.MultiBuffer, error) {
// ReadMultiBuffer implements buf.Reader.
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.eof {
return nil, io.EOF
}
@ -70,19 +65,22 @@ func (r *PacketReader) Read() (buf.MultiBuffer, error) {
return buf.NewMultiBufferValue(b), nil
}
// StreamReader reads Mux frame as a stream.
type StreamReader struct {
reader io.Reader
reader *buf.BufferedReader
leftOver int
}
func NewStreamReader(reader io.Reader) *StreamReader {
// NewStreamReader creates a new StreamReader.
func NewStreamReader(reader *buf.BufferedReader) *StreamReader {
return &StreamReader{
reader: reader,
leftOver: -1,
}
}
func (r *StreamReader) Read() (buf.MultiBuffer, error) {
// ReadMultiBuffer implmenets buf.Reader.
func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
if r.leftOver == 0 {
r.leftOver = -1
return nil, io.EOF
@ -96,24 +94,7 @@ func (r *StreamReader) Read() (buf.MultiBuffer, error) {
r.leftOver = int(size)
}
mb := buf.NewMultiBuffer()
for r.leftOver > 0 {
readLen := buf.Size
if r.leftOver < readLen {
readLen = r.leftOver
}
b := buf.New()
if err := b.AppendSupplier(func(bb []byte) (int, error) {
return r.reader.Read(bb[:readLen])
}); err != nil {
mb.Release()
return nil, err
}
r.leftOver -= b.Len()
mb.Append(b)
if b.Len() < readLen {
break
}
}
return mb, nil
mb, err := r.reader.ReadAtMost(r.leftOver)
r.leftOver -= mb.Len()
return mb, err
}

View File

@ -1,7 +1,6 @@
package mux
import (
"io"
"sync"
"v2ray.com/core/common/buf"
@ -19,7 +18,7 @@ type SessionManager struct {
func NewSessionManager() *SessionManager {
return &SessionManager{
count: 0,
sessions: make(map[uint16]*Session, 32),
sessions: make(map[uint16]*Session, 16),
}
}
@ -58,6 +57,10 @@ func (m *SessionManager) Add(s *Session) {
m.Lock()
defer m.Unlock()
if m.closed {
return
}
m.sessions[s.ID] = s
}
@ -65,6 +68,10 @@ func (m *SessionManager) Remove(id uint16) {
m.Lock()
defer m.Unlock()
if m.closed {
return
}
delete(m.sessions, id)
}
@ -111,9 +118,10 @@ func (m *SessionManager) Close() {
s.output.Close()
}
m.sessions = make(map[uint16]*Session)
m.sessions = nil
}
// Session represents a client connection in a Mux connection.
type Session struct {
input ray.InputStream
output ray.OutputStream
@ -122,13 +130,15 @@ type Session struct {
transferType protocol.TransferType
}
// Close closes all resources associated with this session.
func (s *Session) Close() {
s.output.Close()
s.input.Close()
s.parent.Remove(s.ID)
}
func (s *Session) NewReader(reader io.Reader) buf.Reader {
// NewReader creates a buf.Reader based on the transfer type of this Session.
func (s *Session) NewReader(reader *buf.BufferedReader) buf.Reader {
if s.transferType == protocol.TransferTypeStream {
return NewStreamReader(reader)
}

View File

@ -4,34 +4,36 @@ import (
"testing"
. "v2ray.com/core/app/proxyman/mux"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestSessionManagerAdd(t *testing.T) {
assert := assert.On(t)
assert := With(t)
m := NewSessionManager()
s := m.Allocate()
assert.Uint16(s.ID).Equals(1)
assert(s.ID, Equals, uint16(1))
assert(m.Size(), Equals, 1)
s = m.Allocate()
assert.Uint16(s.ID).Equals(2)
assert(s.ID, Equals, uint16(2))
assert(m.Size(), Equals, 2)
s = &Session{
ID: 4,
}
m.Add(s)
assert.Uint16(s.ID).Equals(4)
assert(s.ID, Equals, uint16(4))
}
func TestSessionManagerClose(t *testing.T) {
assert := assert.On(t)
assert := With(t)
m := NewSessionManager()
s := m.Allocate()
assert.Bool(m.CloseIfNoSession()).IsFalse()
assert(m.CloseIfNoSession(), IsFalse)
m.Remove(s.ID)
assert.Bool(m.CloseIfNoSession()).IsTrue()
assert(m.CloseIfNoSession(), IsTrue)
}

View File

@ -1,8 +1,7 @@
package mux
import (
"runtime"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
@ -10,9 +9,9 @@ import (
)
type Writer struct {
id uint16
dest net.Destination
writer buf.Writer
id uint16
followup bool
transferType protocol.TransferType
}
@ -54,51 +53,47 @@ func (w *Writer) getNextFrameMeta() FrameMetadata {
func (w *Writer) writeMetaOnly() error {
meta := w.getNextFrameMeta()
b := buf.New()
if err := b.AppendSupplier(meta.AsSupplier()); err != nil {
if err := b.Reset(meta.AsSupplier()); err != nil {
return err
}
runtime.KeepAlive(meta)
return w.writer.Write(buf.NewMultiBufferValue(b))
return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
}
func (w *Writer) writeData(mb buf.MultiBuffer) error {
meta := w.getNextFrameMeta()
meta.Option.Add(OptionData)
meta.Option.Set(OptionData)
frame := buf.New()
if err := frame.AppendSupplier(meta.AsSupplier()); err != nil {
if err := frame.Reset(meta.AsSupplier()); err != nil {
return err
}
runtime.KeepAlive(meta)
if err := frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))); err != nil {
return err
}
mb2 := buf.NewMultiBuffer()
mb2 := buf.NewMultiBufferCap(len(mb) + 1)
mb2.Append(frame)
mb2.AppendMulti(mb)
return w.writer.Write(mb2)
return w.writer.WriteMultiBuffer(mb2)
}
// Write implements buf.MultiBufferWriter.
func (w *Writer) Write(mb buf.MultiBuffer) error {
// WriteMultiBuffer implements buf.Writer.
func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
if mb.IsEmpty() {
return w.writeMetaOnly()
}
if w.transferType == protocol.TransferTypeStream {
const chunkSize = 8 * 1024
for !mb.IsEmpty() {
slice := mb.SliceBySize(chunkSize)
if err := w.writeData(slice); err != nil {
return err
}
for !mb.IsEmpty() {
var chunk buf.MultiBuffer
if w.transferType == protocol.TransferTypeStream {
chunk = mb.SliceBySize(8 * 1024)
} else {
chunk = buf.NewMultiBufferValue(mb.SplitFirst())
}
} else {
for _, b := range mb {
if err := w.writeData(buf.NewMultiBufferValue(b)); err != nil {
return err
}
if err := w.writeData(chunk); err != nil {
return err
}
}
@ -112,8 +107,7 @@ func (w *Writer) Close() {
}
frame := buf.New()
frame.AppendSupplier(meta.AsSupplier())
runtime.KeepAlive(meta)
common.Must(frame.Reset(meta.AsSupplier()))
w.writer.Write(buf.NewMultiBufferValue(frame))
w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame))
}

View File

@ -33,7 +33,7 @@ func NewHandler(ctx context.Context, config *proxyman.OutboundHandlerConfig) (*H
if space == nil {
return nil, newError("no space in context")
}
space.OnInitialize(func() error {
space.On(app.SpaceInitializing, func(interface{}) error {
ohm := proxyman.OutboundHandlerManagerFromSpace(space)
if ohm == nil {
return newError("no OutboundManager in space")
@ -78,6 +78,7 @@ func (h *Handler) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) {
err := h.mux.Dispatch(ctx, outboundRay)
if err != nil {
log.Trace(newError("failed to process outbound traffic").Base(err))
outboundRay.OutboundOutput().CloseError()
}
} else {
err := h.proxy.Process(ctx, outboundRay, h)
@ -122,8 +123,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
}
var (
_ buf.MultiBufferReader = (*Connection)(nil)
_ buf.MultiBufferWriter = (*Connection)(nil)
_ buf.Reader = (*Connection)(nil)
_ buf.Writer = (*Connection)(nil)
)
type Connection struct {
@ -132,9 +133,8 @@ type Connection struct {
localAddr net.Addr
remoteAddr net.Addr
bytesReader io.Reader
reader buf.Reader
writer buf.Writer
reader *buf.BufferedReader
writer buf.Writer
}
func NewConnection(stream ray.Ray) *Connection {
@ -148,9 +148,8 @@ func NewConnection(stream ray.Ray) *Connection {
IP: []byte{0, 0, 0, 0},
Port: 0,
},
bytesReader: buf.ToBytesReader(stream.InboundOutput()),
reader: stream.InboundOutput(),
writer: stream.InboundInput(),
reader: buf.NewBufferedReader(stream.InboundOutput()),
writer: stream.InboundInput(),
}
}
@ -159,11 +158,11 @@ func (v *Connection) Read(b []byte) (int, error) {
if v.closed {
return 0, io.EOF
}
return v.bytesReader.Read(b)
return v.reader.Read(b)
}
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
return v.reader.Read()
return v.reader.ReadMultiBuffer()
}
// Write implements net.Conn.Write().
@ -171,14 +170,19 @@ func (v *Connection) Write(b []byte) (int, error) {
if v.closed {
return 0, io.ErrClosedPipe
}
return buf.ToBytesWriter(v.writer).Write(b)
l := len(b)
mb := buf.NewMultiBufferCap(l/buf.Size + 1)
mb.Write(b)
return l, v.writer.WriteMultiBuffer(mb)
}
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
if v.closed {
return io.ErrClosedPipe
}
return v.writer.Write(mb)
return v.writer.WriteMultiBuffer(mb)
}
// Close implements net.Conn.Close().

View File

@ -4,6 +4,8 @@ import (
"context"
"regexp"
"strings"
"sync"
"time"
"v2ray.com/core/common/net"
"v2ray.com/core/common/protocol"
@ -64,13 +66,123 @@ func (v *AnyCondition) Len() int {
return len(*v)
}
type PlainDomainMatcher string
func NewPlainDomainMatcher(pattern string) Condition {
return PlainDomainMatcher(pattern)
type timedResult struct {
timestamp time.Time
result bool
}
func (v PlainDomainMatcher) Apply(ctx context.Context) bool {
type CachableDomainMatcher struct {
sync.Mutex
matchers []domainMatcher
cache map[string]timedResult
lastScan time.Time
}
func NewCachableDomainMatcher() *CachableDomainMatcher {
return &CachableDomainMatcher{
matchers: make([]domainMatcher, 0, 64),
cache: make(map[string]timedResult, 512),
}
}
func (m *CachableDomainMatcher) Add(domain *Domain) error {
switch domain.Type {
case Domain_Plain:
m.matchers = append(m.matchers, NewPlainDomainMatcher(domain.Value))
case Domain_Regex:
rm, err := NewRegexpDomainMatcher(domain.Value)
if err != nil {
return err
}
m.matchers = append(m.matchers, rm)
case Domain_Domain:
m.matchers = append(m.matchers, NewSubDomainMatcher(domain.Value))
default:
return newError("unknown domain type: ", domain.Type).AtError()
}
return nil
}
func (m *CachableDomainMatcher) applyInternal(domain string) bool {
for _, matcher := range m.matchers {
if matcher.Apply(domain) {
return true
}
}
return false
}
type cacheResult int
const (
cacheMiss cacheResult = iota
cacheHitTrue
cacheHitFalse
)
func (m *CachableDomainMatcher) findInCache(domain string) cacheResult {
m.Lock()
defer m.Unlock()
r, f := m.cache[domain]
if !f {
return cacheMiss
}
r.timestamp = time.Now()
m.cache[domain] = r
if r.result {
return cacheHitTrue
}
return cacheHitFalse
}
func (m *CachableDomainMatcher) ApplyDomain(domain string) bool {
if len(m.matchers) < 64 {
return m.applyInternal(domain)
}
cr := m.findInCache(domain)
if cr == cacheHitTrue {
return true
}
if cr == cacheHitFalse {
return false
}
r := m.applyInternal(domain)
m.Lock()
defer m.Unlock()
m.cache[domain] = timedResult{
result: r,
timestamp: time.Now(),
}
now := time.Now()
if len(m.cache) > 256 && now.Sub(m.lastScan)/time.Second > 5 {
remove := make([]string, 0, 128)
now := time.Now()
for k, v := range m.cache {
if now.Sub(v.timestamp)/time.Second > 60 {
remove = append(remove, k)
}
}
for _, v := range remove {
delete(m.cache, v)
}
m.lastScan = now
}
return r
}
func (m *CachableDomainMatcher) Apply(ctx context.Context) bool {
dest, ok := proxy.TargetFromContext(ctx)
if !ok {
return false
@ -79,7 +191,20 @@ func (v PlainDomainMatcher) Apply(ctx context.Context) bool {
if !dest.Address.Family().IsDomain() {
return false
}
domain := dest.Address.Domain()
return m.ApplyDomain(dest.Address.Domain())
}
type domainMatcher interface {
Apply(domain string) bool
}
type PlainDomainMatcher string
func NewPlainDomainMatcher(pattern string) PlainDomainMatcher {
return PlainDomainMatcher(pattern)
}
func (v PlainDomainMatcher) Apply(domain string) bool {
return strings.Contains(domain, string(v))
}
@ -97,33 +222,17 @@ func NewRegexpDomainMatcher(pattern string) (*RegexpDomainMatcher, error) {
}, nil
}
func (v *RegexpDomainMatcher) Apply(ctx context.Context) bool {
dest, ok := proxy.TargetFromContext(ctx)
if !ok {
return false
}
if !dest.Address.Family().IsDomain() {
return false
}
domain := dest.Address.Domain()
func (v *RegexpDomainMatcher) Apply(domain string) bool {
return v.pattern.MatchString(strings.ToLower(domain))
}
type SubDomainMatcher string
func NewSubDomainMatcher(p string) Condition {
func NewSubDomainMatcher(p string) SubDomainMatcher {
return SubDomainMatcher(p)
}
func (m SubDomainMatcher) Apply(ctx context.Context) bool {
dest, ok := proxy.TargetFromContext(ctx)
if !ok {
return false
}
if !dest.Address.Family().IsDomain() {
return false
}
domain := dest.Address.Domain()
func (m SubDomainMatcher) Apply(domain string) bool {
pattern := string(m)
if !strings.HasSuffix(domain, pattern) {
return false
@ -149,8 +258,9 @@ func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error)
func (v *CIDRMatcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4)
if resolveIPs, ok := proxy.ResolvedIPsFromContext(ctx); ok {
for _, rip := range resolveIPs {
if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
if !rip.Family().IsIPv6() {
continue
}
@ -192,8 +302,9 @@ func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
func (v *IPv4Matcher) Apply(ctx context.Context) bool {
ips := make([]net.IP, 0, 4)
if resolveIPs, ok := proxy.ResolvedIPsFromContext(ctx); ok {
for _, rip := range resolveIPs {
if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok {
resolvedIPs := resolver.Resolve()
for _, rip := range resolvedIPs {
if !rip.Family().IsIPv4() {
continue
}

View File

@ -2,57 +2,66 @@ package router_test
import (
"context"
"os"
"path/filepath"
"strconv"
"testing"
"time"
proto "github.com/golang/protobuf/proto"
. "v2ray.com/core/app/router"
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
"v2ray.com/core/common/net"
"v2ray.com/core/common/platform"
"v2ray.com/core/common/protocol"
"v2ray.com/core/proxy"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
"v2ray.com/ext/sysio"
)
func TestSubDomainMatcher(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cases := []struct {
pattern string
input context.Context
input string
output bool
}{
{
pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v2ray.com"), 80)),
input: "www.v2ray.com",
output: true,
},
{
pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.com"), 80)),
input: "v2ray.com",
output: true,
},
{
pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v3ray.com"), 80)),
input: "www.v3ray.com",
output: false,
},
{
pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("2ray.com"), 80)),
input: "2ray.com",
output: false,
},
{
pattern: "v2ray.com",
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("xv2ray.com"), 80)),
input: "xv2ray.com",
output: false,
},
}
for _, test := range cases {
matcher := NewSubDomainMatcher(test.pattern)
assert.Bool(matcher.Apply(test.input) == test.output).IsTrue()
assert(matcher.Apply(test.input) == test.output, IsTrue)
}
}
func TestRoutingRule(t *testing.T) {
assert := assert.On(t)
assert := With(t)
type ruleTest struct {
input context.Context
@ -172,10 +181,56 @@ func TestRoutingRule(t *testing.T) {
for _, test := range cases {
cond, err := test.rule.BuildCondition()
assert.Error(err).IsNil()
assert(err, IsNil)
for _, t := range test.test {
assert.Bool(cond.Apply(t.input)).Equals(t.output)
assert(cond.Apply(t.input), Equals, t.output)
}
}
}
func loadGeoSite(country string) ([]*Domain, error) {
geositeBytes, err := sysio.ReadAsset("geosite.dat")
if err != nil {
return nil, err
}
var geositeList GeoSiteList
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
return nil, err
}
for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
}
}
return nil, errors.New("country not found: " + country)
}
func TestChinaSites(t *testing.T) {
assert := With(t)
common.Must(sysio.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geosite.dat")))
domains, err := loadGeoSite("CN")
assert(err, IsNil)
matcher := NewCachableDomainMatcher()
for _, d := range domains {
assert(matcher.Add(d), IsNil)
}
assert(matcher.ApplyDomain("163.com"), IsTrue)
assert(matcher.ApplyDomain("163.com"), IsTrue)
assert(matcher.ApplyDomain("164.com"), IsFalse)
assert(matcher.ApplyDomain("164.com"), IsFalse)
for i := 0; i < 1024; i++ {
assert(matcher.ApplyDomain(strconv.Itoa(i)+".not-exists.com"), IsFalse)
}
time.Sleep(time.Second * 10)
for i := 0; i < 1024; i++ {
assert(matcher.ApplyDomain(strconv.Itoa(i)+".not-exists2.com"), IsFalse)
}
}

View File

@ -52,24 +52,11 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
conds := NewConditionChan()
if len(rr.Domain) > 0 {
anyCond := NewAnyCondition()
matcher := NewCachableDomainMatcher()
for _, domain := range rr.Domain {
switch domain.Type {
case Domain_Plain:
anyCond.Add(NewPlainDomainMatcher(domain.Value))
case Domain_Regex:
matcher, err := NewRegexpDomainMatcher(domain.Value)
if err != nil {
return nil, err
}
anyCond.Add(matcher)
case Domain_Domain:
anyCond.Add(NewSubDomainMatcher(domain.Value))
default:
panic("Unknown domain type.")
}
matcher.Add(domain)
}
conds.Add(anyCond)
conds.Add(matcher)
}
if len(rr.Cidr) > 0 {

View File

@ -54,23 +54,27 @@ const (
Config_UseIp Config_DomainStrategy = 1
// Resolve to IP if the domain doesn't match any rules.
Config_IpIfNonMatch Config_DomainStrategy = 2
// Resolve to IP if any rule requires IP matching.
Config_IpOnDemand Config_DomainStrategy = 3
)
var Config_DomainStrategy_name = map[int32]string{
0: "AsIs",
1: "UseIp",
2: "IpIfNonMatch",
3: "IpOnDemand",
}
var Config_DomainStrategy_value = map[string]int32{
"AsIs": 0,
"UseIp": 1,
"IpIfNonMatch": 2,
"IpOnDemand": 3,
}
func (x Config_DomainStrategy) String() string {
return proto.EnumName(Config_DomainStrategy_name, int32(x))
}
func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{3, 0} }
func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{7, 0} }
// Domain for routing decision.
type Domain struct {
@ -126,6 +130,86 @@ func (m *CIDR) GetPrefix() uint32 {
return 0
}
type GeoIP struct {
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode" json:"country_code,omitempty"`
Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr" json:"cidr,omitempty"`
}
func (m *GeoIP) Reset() { *m = GeoIP{} }
func (m *GeoIP) String() string { return proto.CompactTextString(m) }
func (*GeoIP) ProtoMessage() {}
func (*GeoIP) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (m *GeoIP) GetCountryCode() string {
if m != nil {
return m.CountryCode
}
return ""
}
func (m *GeoIP) GetCidr() []*CIDR {
if m != nil {
return m.Cidr
}
return nil
}
type GeoIPList struct {
Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry" json:"entry,omitempty"`
}
func (m *GeoIPList) Reset() { *m = GeoIPList{} }
func (m *GeoIPList) String() string { return proto.CompactTextString(m) }
func (*GeoIPList) ProtoMessage() {}
func (*GeoIPList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (m *GeoIPList) GetEntry() []*GeoIP {
if m != nil {
return m.Entry
}
return nil
}
type GeoSite struct {
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode" json:"country_code,omitempty"`
Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"`
}
func (m *GeoSite) Reset() { *m = GeoSite{} }
func (m *GeoSite) String() string { return proto.CompactTextString(m) }
func (*GeoSite) ProtoMessage() {}
func (*GeoSite) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
func (m *GeoSite) GetCountryCode() string {
if m != nil {
return m.CountryCode
}
return ""
}
func (m *GeoSite) GetDomain() []*Domain {
if m != nil {
return m.Domain
}
return nil
}
type GeoSiteList struct {
Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry" json:"entry,omitempty"`
}
func (m *GeoSiteList) Reset() { *m = GeoSiteList{} }
func (m *GeoSiteList) String() string { return proto.CompactTextString(m) }
func (*GeoSiteList) ProtoMessage() {}
func (*GeoSiteList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
func (m *GeoSiteList) GetEntry() []*GeoSite {
if m != nil {
return m.Entry
}
return nil
}
type RoutingRule struct {
Tag string `protobuf:"bytes,1,opt,name=tag" json:"tag,omitempty"`
Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"`
@ -140,7 +224,7 @@ type RoutingRule struct {
func (m *RoutingRule) Reset() { *m = RoutingRule{} }
func (m *RoutingRule) String() string { return proto.CompactTextString(m) }
func (*RoutingRule) ProtoMessage() {}
func (*RoutingRule) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
func (*RoutingRule) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
func (m *RoutingRule) GetTag() string {
if m != nil {
@ -206,7 +290,7 @@ type Config struct {
func (m *Config) Reset() { *m = Config{} }
func (m *Config) String() string { return proto.CompactTextString(m) }
func (*Config) ProtoMessage() {}
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
func (m *Config) GetDomainStrategy() Config_DomainStrategy {
if m != nil {
@ -225,6 +309,10 @@ func (m *Config) GetRule() []*RoutingRule {
func init() {
proto.RegisterType((*Domain)(nil), "v2ray.core.app.router.Domain")
proto.RegisterType((*CIDR)(nil), "v2ray.core.app.router.CIDR")
proto.RegisterType((*GeoIP)(nil), "v2ray.core.app.router.GeoIP")
proto.RegisterType((*GeoIPList)(nil), "v2ray.core.app.router.GeoIPList")
proto.RegisterType((*GeoSite)(nil), "v2ray.core.app.router.GeoSite")
proto.RegisterType((*GeoSiteList)(nil), "v2ray.core.app.router.GeoSiteList")
proto.RegisterType((*RoutingRule)(nil), "v2ray.core.app.router.RoutingRule")
proto.RegisterType((*Config)(nil), "v2ray.core.app.router.Config")
proto.RegisterEnum("v2ray.core.app.router.Domain_Type", Domain_Type_name, Domain_Type_value)
@ -234,39 +322,45 @@ func init() {
func init() { proto.RegisterFile("v2ray.com/core/app/router/config.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 538 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x93, 0xc1, 0x6e, 0xd4, 0x30,
0x10, 0x86, 0x49, 0x76, 0x1b, 0xba, 0x93, 0xb2, 0x44, 0x16, 0x45, 0xa1, 0xa8, 0x22, 0x8a, 0x10,
0xe4, 0x80, 0x12, 0x69, 0x11, 0x70, 0x01, 0xa1, 0xb2, 0xed, 0x61, 0x25, 0xa8, 0x2a, 0xd3, 0x72,
0xe0, 0x12, 0xb9, 0x59, 0x37, 0x58, 0x24, 0xb6, 0xe5, 0x38, 0xa5, 0x7b, 0xe3, 0x05, 0x78, 0x11,
0x9e, 0x86, 0x47, 0x42, 0xb6, 0x53, 0xd1, 0xa2, 0x2e, 0xdc, 0x66, 0x9c, 0xef, 0x9f, 0x19, 0x8f,
0xff, 0xc0, 0x93, 0xf3, 0x99, 0x22, 0xab, 0xbc, 0x12, 0x6d, 0x51, 0x09, 0x45, 0x0b, 0x22, 0x65,
0xa1, 0x44, 0xaf, 0xa9, 0x2a, 0x2a, 0xc1, 0xcf, 0x58, 0x9d, 0x4b, 0x25, 0xb4, 0x40, 0xdb, 0x97,
0x9c, 0xa2, 0x39, 0x91, 0x32, 0x77, 0xcc, 0xce, 0xe3, 0xbf, 0xe4, 0x95, 0x68, 0x5b, 0xc1, 0x0b,
0x4e, 0x75, 0x21, 0x85, 0xd2, 0x4e, 0xbc, 0xf3, 0x74, 0x3d, 0xc5, 0xa9, 0xfe, 0x26, 0xd4, 0x57,
0x07, 0xa6, 0xdf, 0x3d, 0x08, 0xf6, 0x45, 0x4b, 0x18, 0x47, 0x2f, 0x61, 0xac, 0x57, 0x92, 0xc6,
0x5e, 0xe2, 0x65, 0xd3, 0x59, 0x9a, 0xdf, 0xd8, 0x3f, 0x77, 0x70, 0x7e, 0xbc, 0x92, 0x14, 0x5b,
0x1e, 0xdd, 0x83, 0x8d, 0x73, 0xd2, 0xf4, 0x34, 0xf6, 0x13, 0x2f, 0x9b, 0x60, 0x97, 0xa4, 0x19,
0x8c, 0x0d, 0x83, 0x26, 0xb0, 0x71, 0xd4, 0x10, 0xc6, 0xa3, 0x5b, 0x26, 0xc4, 0xb4, 0xa6, 0x17,
0x91, 0x87, 0xe0, 0xb2, 0x6b, 0xe4, 0xa7, 0x39, 0x8c, 0xe7, 0x8b, 0x7d, 0x8c, 0xa6, 0xe0, 0x33,
0x69, 0xbb, 0x6f, 0x61, 0x9f, 0x49, 0x74, 0x1f, 0x02, 0xa9, 0xe8, 0x19, 0xbb, 0xb0, 0x85, 0xef,
0xe0, 0x21, 0x4b, 0x7f, 0x8c, 0x20, 0xc4, 0xa2, 0xd7, 0x8c, 0xd7, 0xb8, 0x6f, 0x28, 0x8a, 0x60,
0xa4, 0x49, 0x6d, 0x85, 0x13, 0x6c, 0x42, 0xf4, 0x02, 0x82, 0xa5, 0xad, 0x1e, 0xfb, 0xc9, 0x28,
0x0b, 0x67, 0xbb, 0xff, 0xbc, 0x0b, 0x1e, 0x60, 0x54, 0xc0, 0xb8, 0x62, 0x4b, 0x15, 0x8f, 0xac,
0xe8, 0xe1, 0x1a, 0x91, 0x99, 0x15, 0x5b, 0x10, 0xbd, 0x05, 0x30, 0x3b, 0x2f, 0x15, 0xe1, 0x35,
0x8d, 0xc7, 0x89, 0x97, 0x85, 0xb3, 0xe4, 0xaa, 0xcc, 0xad, 0x3d, 0xe7, 0x54, 0xe7, 0x47, 0x42,
0x69, 0x6c, 0x38, 0x3c, 0x91, 0x97, 0x21, 0x3a, 0x80, 0xad, 0xe1, 0x39, 0xca, 0x86, 0x75, 0x3a,
0xde, 0xb0, 0x25, 0xd2, 0x35, 0x25, 0x0e, 0x1d, 0xfa, 0x9e, 0x75, 0x1a, 0x87, 0xfc, 0x4f, 0x82,
0x5e, 0x43, 0xd8, 0x89, 0x5e, 0x55, 0xb4, 0xb4, 0xf3, 0x07, 0xff, 0x9f, 0x1f, 0x1c, 0x3f, 0x37,
0xb7, 0xd8, 0x05, 0xe8, 0x3b, 0xaa, 0x4a, 0xda, 0x12, 0xd6, 0xc4, 0xb7, 0x93, 0x51, 0x36, 0xc1,
0x13, 0x73, 0x72, 0x60, 0x0e, 0xd0, 0x23, 0x08, 0x19, 0x3f, 0x15, 0x3d, 0x5f, 0x96, 0x66, 0xcd,
0x9b, 0xf6, 0x3b, 0x0c, 0x47, 0xc7, 0xa4, 0x4e, 0x7f, 0x79, 0x10, 0xcc, 0xad, 0x73, 0xd1, 0x09,
0xdc, 0x75, 0xbb, 0x2c, 0x3b, 0xad, 0x88, 0xa6, 0xf5, 0x6a, 0x70, 0xd3, 0xb3, 0x75, 0xc3, 0x38,
0xc7, 0xbb, 0x87, 0xf8, 0x38, 0x68, 0xf0, 0x74, 0x79, 0x2d, 0x37, 0xce, 0x54, 0x7d, 0x43, 0x87,
0xd7, 0x5c, 0xe7, 0xcc, 0x2b, 0x9e, 0xc0, 0x96, 0x4f, 0x5f, 0xc1, 0xf4, 0x7a, 0x65, 0xb4, 0x09,
0xe3, 0xbd, 0x6e, 0xd1, 0x39, 0x33, 0x9e, 0x74, 0x74, 0x21, 0x23, 0x0f, 0x45, 0xb0, 0xb5, 0x90,
0x8b, 0xb3, 0x43, 0xc1, 0x3f, 0x10, 0x5d, 0x7d, 0x89, 0xfc, 0x77, 0x6f, 0xe0, 0x41, 0x25, 0xda,
0x9b, 0xfb, 0x1c, 0x79, 0x9f, 0x03, 0x17, 0xfd, 0xf4, 0xb7, 0x3f, 0xcd, 0x30, 0x59, 0xe5, 0x73,
0x43, 0xec, 0x49, 0x69, 0x47, 0xa0, 0xea, 0x34, 0xb0, 0xff, 0xd6, 0xf3, 0xdf, 0x01, 0x00, 0x00,
0xff, 0xff, 0xa7, 0x6a, 0x97, 0x93, 0xeb, 0x03, 0x00, 0x00,
// 640 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x94, 0xcd, 0x6e, 0xd4, 0x3a,
0x14, 0xc7, 0x6f, 0xe6, 0xab, 0x9d, 0x93, 0xb9, 0x73, 0x23, 0xeb, 0x16, 0x0d, 0x85, 0xc2, 0x10,
0x21, 0x98, 0x05, 0x4a, 0xa4, 0xe1, 0x63, 0x05, 0xaa, 0xca, 0xb4, 0xaa, 0x22, 0x41, 0x19, 0xb9,
0x2d, 0x0b, 0x58, 0x44, 0x6e, 0xe2, 0x86, 0x88, 0x89, 0x6d, 0x39, 0x4e, 0xe9, 0xec, 0x78, 0x01,
0x5e, 0x84, 0xa7, 0xe2, 0x51, 0x90, 0xed, 0x0c, 0xb4, 0xa8, 0x81, 0x8a, 0x9d, 0xed, 0xfc, 0xfe,
0xe7, 0xfc, 0x73, 0x7c, 0x8e, 0xe1, 0xc1, 0xd9, 0x54, 0x92, 0x65, 0x90, 0xf0, 0x22, 0x4c, 0xb8,
0xa4, 0x21, 0x11, 0x22, 0x94, 0xbc, 0x52, 0x54, 0x86, 0x09, 0x67, 0xa7, 0x79, 0x16, 0x08, 0xc9,
0x15, 0x47, 0x1b, 0x2b, 0x4e, 0xd2, 0x80, 0x08, 0x11, 0x58, 0x66, 0xf3, 0xfe, 0x2f, 0xf2, 0x84,
0x17, 0x05, 0x67, 0x21, 0xa3, 0x2a, 0x14, 0x5c, 0x2a, 0x2b, 0xde, 0x7c, 0xd8, 0x4c, 0x31, 0xaa,
0x3e, 0x71, 0xf9, 0xd1, 0x82, 0xfe, 0x67, 0x07, 0x7a, 0xbb, 0xbc, 0x20, 0x39, 0x43, 0xcf, 0xa0,
0xa3, 0x96, 0x82, 0x8e, 0x9c, 0xb1, 0x33, 0x19, 0x4e, 0xfd, 0xe0, 0xca, 0xfc, 0x81, 0x85, 0x83,
0xa3, 0xa5, 0xa0, 0xd8, 0xf0, 0xe8, 0x7f, 0xe8, 0x9e, 0x91, 0x45, 0x45, 0x47, 0xad, 0xb1, 0x33,
0xe9, 0x63, 0xbb, 0xf1, 0x27, 0xd0, 0xd1, 0x0c, 0xea, 0x43, 0x77, 0xbe, 0x20, 0x39, 0xf3, 0xfe,
0xd1, 0x4b, 0x4c, 0x33, 0x7a, 0xee, 0x39, 0x08, 0x56, 0x59, 0xbd, 0x96, 0x1f, 0x40, 0x67, 0x16,
0xed, 0x62, 0x34, 0x84, 0x56, 0x2e, 0x4c, 0xf6, 0x01, 0x6e, 0xe5, 0x02, 0xdd, 0x80, 0x9e, 0x90,
0xf4, 0x34, 0x3f, 0x37, 0x81, 0xff, 0xc5, 0xf5, 0xce, 0x7f, 0x0f, 0xdd, 0x7d, 0xca, 0xa3, 0x39,
0xba, 0x07, 0x83, 0x84, 0x57, 0x4c, 0xc9, 0x65, 0x9c, 0xf0, 0xd4, 0x1a, 0xef, 0x63, 0xb7, 0x3e,
0x9b, 0xf1, 0x94, 0xa2, 0x10, 0x3a, 0x49, 0x9e, 0xca, 0x51, 0x6b, 0xdc, 0x9e, 0xb8, 0xd3, 0x5b,
0x0d, 0xff, 0xa4, 0xd3, 0x63, 0x03, 0xfa, 0xdb, 0xd0, 0x37, 0xc1, 0x5f, 0xe5, 0xa5, 0x42, 0x53,
0xe8, 0x52, 0x1d, 0x6a, 0xe4, 0x18, 0xf9, 0xed, 0x06, 0xb9, 0x11, 0x60, 0x8b, 0xfa, 0x09, 0xac,
0xed, 0x53, 0x7e, 0x98, 0x2b, 0x7a, 0x1d, 0x7f, 0x4f, 0xa1, 0x97, 0x9a, 0x3a, 0xd4, 0x0e, 0xb7,
0x7e, 0x5b, 0x75, 0x5c, 0xc3, 0xfe, 0x0c, 0xdc, 0x3a, 0x89, 0xf1, 0xf9, 0xe4, 0xb2, 0xcf, 0x3b,
0xcd, 0x3e, 0xb5, 0x64, 0xe5, 0xf4, 0x4b, 0x1b, 0x5c, 0xcc, 0x2b, 0x95, 0xb3, 0x0c, 0x57, 0x0b,
0x8a, 0x3c, 0x68, 0x2b, 0x92, 0xd5, 0x2e, 0xf5, 0xf2, 0x2f, 0xdd, 0xfd, 0x28, 0x7a, 0xfb, 0x9a,
0x45, 0x47, 0xdb, 0x00, 0xba, 0x77, 0x63, 0x49, 0x58, 0x46, 0x47, 0x9d, 0xb1, 0x33, 0x71, 0xa7,
0xe3, 0x8b, 0x32, 0xdb, 0xbe, 0x01, 0xa3, 0x2a, 0x98, 0x73, 0xa9, 0xb0, 0xe6, 0x70, 0x5f, 0xac,
0x96, 0x68, 0x0f, 0x06, 0x75, 0x5b, 0xc7, 0x8b, 0xbc, 0x54, 0xa3, 0xae, 0x09, 0xe1, 0x37, 0x84,
0x38, 0xb0, 0xa8, 0x2e, 0x1d, 0x76, 0xd9, 0xcf, 0x0d, 0x7a, 0x0e, 0x6e, 0xc9, 0x2b, 0x99, 0xd0,
0xd8, 0xf8, 0xef, 0xfd, 0xd9, 0x3f, 0x58, 0x7e, 0xa6, 0xff, 0x62, 0x0b, 0xa0, 0x2a, 0xa9, 0x8c,
0x69, 0x41, 0xf2, 0xc5, 0x68, 0x6d, 0xdc, 0x9e, 0xf4, 0x71, 0x5f, 0x9f, 0xec, 0xe9, 0x03, 0x74,
0x17, 0xdc, 0x9c, 0x9d, 0xf0, 0x8a, 0xa5, 0xb1, 0x2e, 0xf3, 0xba, 0xf9, 0x0e, 0xf5, 0xd1, 0x11,
0xc9, 0xfc, 0x6f, 0x0e, 0xf4, 0x66, 0xe6, 0x05, 0x40, 0xc7, 0xf0, 0x9f, 0xad, 0x65, 0x5c, 0x2a,
0x49, 0x14, 0xcd, 0x96, 0xf5, 0x54, 0x3e, 0x6a, 0x32, 0x63, 0x5f, 0x0e, 0x7b, 0x11, 0x87, 0xb5,
0x06, 0x0f, 0xd3, 0x4b, 0x7b, 0x3d, 0xe1, 0xb2, 0x5a, 0xd0, 0xfa, 0x36, 0x9b, 0x26, 0xfc, 0x42,
0x4f, 0x60, 0xc3, 0xfb, 0xfb, 0x30, 0xbc, 0x1c, 0x19, 0xad, 0x43, 0x67, 0xa7, 0x8c, 0x4a, 0x3b,
0xd4, 0xc7, 0x25, 0x8d, 0x84, 0xe7, 0x20, 0x0f, 0x06, 0x91, 0x88, 0x4e, 0x0f, 0x38, 0x7b, 0x4d,
0x54, 0xf2, 0xc1, 0x6b, 0xa1, 0x21, 0x40, 0x24, 0xde, 0xb0, 0x5d, 0x5a, 0x10, 0x96, 0x7a, 0xed,
0x97, 0x2f, 0xe0, 0x66, 0xc2, 0x8b, 0xab, 0xf3, 0xce, 0x9d, 0x77, 0x3d, 0xbb, 0xfa, 0xda, 0xda,
0x78, 0x3b, 0xc5, 0x64, 0x19, 0xcc, 0x34, 0xb1, 0x23, 0x84, 0xb1, 0x44, 0xe5, 0x49, 0xcf, 0xbc,
0x59, 0x8f, 0xbf, 0x07, 0x00, 0x00, 0xff, 0xff, 0x53, 0x7c, 0xa8, 0x94, 0x43, 0x05, 0x00, 0x00,
}

View File

@ -37,6 +37,24 @@ message CIDR {
uint32 prefix = 2;
}
message GeoIP {
string country_code = 1;
repeated CIDR cidr = 2;
}
message GeoIPList {
repeated GeoIP entry = 1;
}
message GeoSite {
string country_code = 1;
repeated Domain domain = 2;
}
message GeoSiteList{
repeated GeoSite entry = 1;
}
message RoutingRule {
string tag = 1;
repeated Domain domain = 2;
@ -58,6 +76,9 @@ message Config {
// Resolve to IP if the domain doesn't match any rules.
IpIfNonMatch = 2;
// Resolve to IP if any rule requires IP matching.
IpOnDemand = 3;
}
DomainStrategy domain_strategy = 1;
repeated RoutingRule rule = 2;

View File

@ -33,7 +33,7 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
rules: make([]Rule, len(config.Rule)),
}
space.OnInitialize(func() error {
space.On(app.SpaceInitializing, func(interface{}) error {
for idx, rule := range config.Rule {
r.rules[idx].Tag = rule.Tag
cond, err := rule.BuildCondition()
@ -52,19 +52,42 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
return r, nil
}
func (r *Router) resolveIP(dest net.Destination) []net.Address {
ips := r.dnsServer.Get(dest.Address.Domain())
type ipResolver struct {
ip []net.Address
domain string
resolved bool
dnsServer dns.Server
}
func (r *ipResolver) Resolve() []net.Address {
if r.resolved {
return r.ip
}
log.Trace(newError("looking for IP for domain: ", r.domain))
r.resolved = true
ips := r.dnsServer.Get(r.domain)
if len(ips) == 0 {
return nil
}
dests := make([]net.Address, len(ips))
for idx, ip := range ips {
dests[idx] = net.IPAddress(ip)
r.ip = make([]net.Address, len(ips))
for i, ip := range ips {
r.ip[i] = net.IPAddress(ip)
}
return dests
return r.ip
}
func (r *Router) TakeDetour(ctx context.Context) (string, error) {
resolver := &ipResolver{
dnsServer: r.dnsServer,
}
if r.domainStrategy == Config_IpOnDemand {
if dest, ok := proxy.TargetFromContext(ctx); ok && dest.Address.Family().IsDomain() {
resolver.domain = dest.Address.Domain()
ctx = proxy.ContextWithResolveIPs(ctx, resolver)
}
}
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule.Tag, nil
@ -77,10 +100,10 @@ func (r *Router) TakeDetour(ctx context.Context) (string, error) {
}
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
log.Trace(newError("looking up IP for ", dest))
ipDests := r.resolveIP(dest)
if ipDests != nil {
ctx = proxy.ContextWithResolveIPs(ctx, ipDests)
resolver.domain = dest.Address.Domain()
ips := resolver.Resolve()
if len(ips) > 0 {
ctx = proxy.ContextWithResolveIPs(ctx, resolver)
for _, rule := range r.rules {
if rule.Apply(ctx) {
return rule.Tag, nil

View File

@ -14,11 +14,11 @@ import (
. "v2ray.com/core/app/router"
"v2ray.com/core/common/net"
"v2ray.com/core/proxy"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestSimpleRouter(t *testing.T) {
assert := assert.On(t)
assert := With(t)
config := &Config{
Rule: []*RoutingRule{
@ -33,16 +33,16 @@ func TestSimpleRouter(t *testing.T) {
space := app.NewSpace()
ctx := app.ContextWithSpace(context.Background(), space)
assert.Error(app.AddApplicationToSpace(ctx, new(dns.Config))).IsNil()
assert.Error(app.AddApplicationToSpace(ctx, new(dispatcher.Config))).IsNil()
assert.Error(app.AddApplicationToSpace(ctx, new(proxyman.OutboundConfig))).IsNil()
assert.Error(app.AddApplicationToSpace(ctx, config)).IsNil()
assert.Error(space.Initialize()).IsNil()
assert(app.AddApplicationToSpace(ctx, new(dns.Config)), IsNil)
assert(app.AddApplicationToSpace(ctx, new(dispatcher.Config)), IsNil)
assert(app.AddApplicationToSpace(ctx, new(proxyman.OutboundConfig)), IsNil)
assert(app.AddApplicationToSpace(ctx, config), IsNil)
assert(space.Initialize(), IsNil)
r := FromSpace(space)
ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(net.DomainAddress("v2ray.com"), 80))
tag, err := r.TakeDetour(ctx)
assert.Error(err).IsNil()
assert.String(tag).Equals("test")
assert(err, IsNil)
assert(tag, Equals, "test")
}

View File

@ -5,6 +5,7 @@ import (
"reflect"
"v2ray.com/core/common"
"v2ray.com/core/common/event"
)
type Application interface {
@ -13,8 +14,6 @@ type Application interface {
Close()
}
type InitializationCallback func() error
func CreateAppFromConfig(ctx context.Context, config interface{}) (Application, error) {
application, err := common.CreateObject(ctx, config)
if err != nil {
@ -29,46 +28,47 @@ func CreateAppFromConfig(ctx context.Context, config interface{}) (Application,
}
// A Space contains all apps that may be available in a V2Ray runtime.
// Caller must check the availability of an app by calling HasXXX before getting its instance.
type Space interface {
event.Registry
GetApplication(appInterface interface{}) Application
AddApplication(application Application) error
Initialize() error
OnInitialize(InitializationCallback)
Start() error
Close()
}
const (
// SpaceInitializing is an event to be fired when Space is being initialized.
SpaceInitializing event.Event = iota
)
type spaceImpl struct {
initialized bool
event.Listener
cache map[reflect.Type]Application
appInit []InitializationCallback
initialized bool
}
// NewSpace creates a new Space.
func NewSpace() Space {
return &spaceImpl{
cache: make(map[reflect.Type]Application),
appInit: make([]InitializationCallback, 0, 32),
cache: make(map[reflect.Type]Application),
}
}
func (s *spaceImpl) OnInitialize(f InitializationCallback) {
if s.initialized {
f()
} else {
s.appInit = append(s.appInit, f)
func (s *spaceImpl) On(e event.Event, h event.Handler) {
if e == SpaceInitializing && s.initialized {
_ = h(nil) // Ignore error
return
}
s.Listener.On(e, h)
}
func (s *spaceImpl) Initialize() error {
for _, f := range s.appInit {
if err := f(); err != nil {
return err
}
if s.initialized {
return nil
}
s.appInit = nil
s.initialized = true
return nil
return s.Fire(SpaceInitializing, nil)
}
func (s *spaceImpl) GetApplication(appInterface interface{}) Application {

View File

@ -1,52 +0,0 @@
package vpndialer
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
type Config struct {
Address string `protobuf:"bytes,1,opt,name=address" json:"address,omitempty"`
}
func (m *Config) Reset() { *m = Config{} }
func (m *Config) String() string { return proto.CompactTextString(m) }
func (*Config) ProtoMessage() {}
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
func (m *Config) GetAddress() string {
if m != nil {
return m.Address
}
return ""
}
func init() {
proto.RegisterType((*Config)(nil), "v2ray.core.app.vpndialer.Config")
}
func init() { proto.RegisterFile("v2ray.com/core/app/vpndialer/config.proto", fileDescriptor0) }
var fileDescriptor0 = []byte{
// 150 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2c, 0x33, 0x2a, 0x4a,
0xac, 0xd4, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xce, 0x2f, 0x4a, 0xd5, 0x4f, 0x2c, 0x28, 0xd0, 0x2f,
0x2b, 0xc8, 0x4b, 0xc9, 0x4c, 0xcc, 0x49, 0x2d, 0xd2, 0x4f, 0xce, 0xcf, 0x4b, 0xcb, 0x4c, 0xd7,
0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x92, 0x80, 0x29, 0x2d, 0x4a, 0xd5, 0x4b, 0x2c, 0x28, 0xd0,
0x83, 0x2b, 0x53, 0x52, 0xe2, 0x62, 0x73, 0x06, 0xab, 0x14, 0x92, 0xe0, 0x62, 0x4f, 0x4c, 0x49,
0x29, 0x4a, 0x2d, 0x2e, 0x96, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x82, 0x71, 0x9d, 0xdc, 0xb8,
0x64, 0x92, 0xf3, 0x73, 0xf5, 0x70, 0x99, 0x11, 0xc0, 0x18, 0xc5, 0x09, 0xe7, 0xac, 0x62, 0x92,
0x08, 0x33, 0x0a, 0x4a, 0xac, 0xd4, 0x73, 0x06, 0xa9, 0x73, 0x2c, 0x28, 0xd0, 0x0b, 0x2b, 0xc8,
0x73, 0x01, 0x4b, 0x25, 0xb1, 0x81, 0x1d, 0x63, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x97, 0xfc,
0x09, 0x70, 0xb9, 0x00, 0x00, 0x00,
}

View File

@ -1,11 +0,0 @@
syntax = "proto3";
package v2ray.core.app.vpndialer;
option csharp_namespace = "V2Ray.Core.App.VpnDialer";
option go_package = "vpndialer";
option java_package = "com.v2ray.core.app.vpndialer";
option java_multiple_files = true;
message Config {
string address = 1;
}

View File

@ -1,7 +0,0 @@
package unix
import "v2ray.com/core/common/errors"
func newError(values ...interface{}) *errors.Error {
return errors.New(values...).Path("App", "VPNDialer", "Unix")
}

View File

@ -1,217 +0,0 @@
package unix
import (
"context"
"os"
"sync"
"golang.org/x/sys/unix"
"v2ray.com/core/app/vpndialer"
"v2ray.com/core/common"
"v2ray.com/core/common/net"
"v2ray.com/core/common/serial"
"v2ray.com/core/transport/internet"
)
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg unix -path App,VPNDialer,Unix
type status int
const (
statusNew status = iota
statusOK
statusFail
)
type fdStatus struct {
status status
fd int
callback chan<- error
}
type protector struct {
sync.Mutex
address string
conn *net.UnixConn
status chan fdStatus
}
func readFrom(conn *net.UnixConn, schan chan<- fdStatus) {
var payload [6]byte
for {
_, err := conn.Read(payload[:])
if err != nil {
break
}
fd := serial.BytesToInt(payload[1:5])
s := status(payload[5])
schan <- fdStatus{
fd: fd,
status: s,
}
}
}
func (m *protector) dial() (*net.UnixConn, error) {
m.Lock()
defer m.Unlock()
if m.conn != nil {
return m.conn, nil
}
conn, err := net.DialUnix("unix", nil, &net.UnixAddr{
Name: m.address,
Net: "unix",
})
if err != nil {
return nil, err
}
m.conn = conn
m.status = make(chan fdStatus, 32)
go readFrom(conn, m.status)
go m.monitor(conn)
return conn, nil
}
func (m *protector) close() {
m.Lock()
defer m.Unlock()
if m.conn == nil {
return
}
m.conn.Close()
m.conn = nil
}
func (m *protector) monitor(c *net.UnixConn) {
pendingFd := make(map[int]chan<- error, 32)
for s := range m.status {
switch s.status {
case statusNew:
pendingFd[s.fd] = s.callback
case statusOK:
if c, f := pendingFd[s.fd]; f {
close(c)
delete(pendingFd, s.fd)
}
case statusFail:
if c, f := pendingFd[s.fd]; f {
c <- newError("failed to protect fd")
close(c)
delete(pendingFd, s.fd)
}
}
}
}
func (m *protector) protect(fd int) error {
conn, err := m.dial()
if err != nil {
return err
}
var payload [6]byte
serial.IntToBytes(fd, payload[1:1])
payload[5] = byte(statusNew)
if _, err := conn.Write(payload[:]); err != nil {
return err
}
wait := make(chan error)
m.status <- fdStatus{
status: statusNew,
fd: fd,
callback: wait,
}
return <-wait
}
type App struct {
protector *protector
dialer *Dialer
}
func NewApp(ctx context.Context, config *vpndialer.Config) (*App, error) {
a := &App{
dialer: &Dialer{},
protector: &protector{
address: config.Address,
},
}
a.dialer.protect = a.protector.protect
return a, nil
}
func (*App) Interface() interface{} {
return (*App)(nil)
}
func (a *App) Start() error {
internet.UseAlternativeSystemDialer(a.dialer)
return nil
}
func (a *App) Close() {
internet.UseAlternativeSystemDialer(nil)
}
type Dialer struct {
protect func(fd int) error
}
func socket(dest net.Destination) (int, error) {
switch dest.Network {
case net.Network_TCP:
return unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, unix.IPPROTO_TCP)
case net.Network_UDP:
return unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
default:
return 0, newError("unknown network ", dest.Network)
}
}
func getIP(addr net.Address) (net.IP, error) {
if addr.Family().Either(net.AddressFamilyIPv4, net.AddressFamilyIPv6) {
return addr.IP(), nil
}
ips, err := net.LookupIP(addr.Domain())
if err != nil {
return nil, err
}
return ips[0], nil
}
func (d *Dialer) Dial(ctx context.Context, source net.Address, dest net.Destination) (net.Conn, error) {
fd, err := socket(dest)
if err != nil {
return nil, err
}
if err := d.protect(fd); err != nil {
return nil, err
}
ip, err := getIP(dest.Address)
if err != nil {
return nil, err
}
addr := &unix.SockaddrInet6{
Port: int(dest.Port),
ZoneId: 0,
}
copy(addr.Addr[:], ip.To16())
if err := unix.Connect(fd, addr); err != nil {
return nil, err
}
file := os.NewFile(uintptr(fd), "Socket")
return net.FileConn(file)
}
func init() {
common.Must(common.RegisterConfig((*vpndialer.Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
return NewApp(ctx, config.(*vpndialer.Config))
}))
}

View File

@ -1 +0,0 @@
package vpndialer

21
common/bitmask/byte.go Normal file
View File

@ -0,0 +1,21 @@
package bitmask
// Byte is a bitmask in byte.
type Byte byte
// Has returns true if this bitmask contains another bitmask.
func (b Byte) Has(bb Byte) bool {
return (b & bb) != 0
}
func (b *Byte) Set(bb Byte) {
*b |= bb
}
func (b *Byte) Clear(bb Byte) {
*b &= ^bb
}
func (b *Byte) Toggle(bb Byte) {
*b ^= bb
}

View File

@ -0,0 +1,27 @@
package bitmask_test
import (
"testing"
. "v2ray.com/core/common/bitmask"
. "v2ray.com/ext/assert"
)
func TestBitmaskByte(t *testing.T) {
assert := With(t)
b := Byte(0)
b.Set(Byte(1))
assert(b.Has(1), IsTrue)
b.Set(Byte(2))
assert(b.Has(2), IsTrue)
assert(b.Has(1), IsTrue)
b.Clear(Byte(1))
assert(b.Has(2), IsTrue)
assert(b.Has(1), IsFalse)
b.Toggle(Byte(2))
assert(b.Has(2), IsFalse)
}

View File

@ -1,10 +1,7 @@
package buf
import (
"runtime"
"sync"
"v2ray.com/core/common/platform"
)
// Pool provides functionality to generate and recycle buffers on demand.
@ -45,79 +42,11 @@ func (p *SyncPool) Free(buffer *Buffer) {
}
}
// BufferPool is a Pool that utilizes an internal cache.
type BufferPool struct {
chain chan []byte
sub Pool
}
// NewBufferPool creates a new BufferPool with given buffer size, and internal cache size.
func NewBufferPool(bufferSize, poolSize uint32) *BufferPool {
pool := &BufferPool{
chain: make(chan []byte, poolSize),
sub: NewSyncPool(bufferSize),
}
for i := uint32(0); i < poolSize; i++ {
pool.chain <- make([]byte, bufferSize)
}
return pool
}
// Allocate implements Pool.Allocate().
func (p *BufferPool) Allocate() *Buffer {
select {
case b := <-p.chain:
return &Buffer{
v: b,
pool: p,
}
default:
return p.sub.Allocate()
}
}
// Free implements Pool.Free().
func (p *BufferPool) Free(buffer *Buffer) {
if buffer.v == nil {
return
}
select {
case p.chain <- buffer.v:
default:
p.sub.Free(buffer)
}
}
const (
// Size of a regular buffer.
Size = 2 * 1024
poolSizeEnvKey = "v2ray.buffer.size"
)
var (
mediumPool Pool
mediumPool Pool = NewSyncPool(Size)
)
func getDefaultPoolSize() int {
switch runtime.GOARCH {
case "amd64", "386":
return 20
default:
return 5
}
}
func init() {
f := platform.EnvFlag{
Name: poolSizeEnvKey,
AltName: platform.NormalizeEnvName(poolSizeEnvKey),
}
size := f.GetValueAsInt(getDefaultPoolSize())
if size > 0 {
totalByteSize := uint32(size) * 1024 * 1024
mediumPool = NewBufferPool(Size, totalByteSize/Size)
} else {
mediumPool = NewSyncPool(Size)
}
}

View File

@ -6,64 +6,64 @@ import (
. "v2ray.com/core/common/buf"
"v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestBufferClear(t *testing.T) {
assert := assert.On(t)
assert := With(t)
buffer := New()
defer buffer.Release()
payload := "Bytes"
buffer.Append([]byte(payload))
assert.Int(buffer.Len()).Equals(len(payload))
assert(buffer.Len(), Equals, len(payload))
buffer.Clear()
assert.Int(buffer.Len()).Equals(0)
assert(buffer.Len(), Equals, 0)
}
func TestBufferIsEmpty(t *testing.T) {
assert := assert.On(t)
assert := With(t)
buffer := New()
defer buffer.Release()
assert.Bool(buffer.IsEmpty()).IsTrue()
assert(buffer.IsEmpty(), IsTrue)
}
func TestBufferString(t *testing.T) {
assert := assert.On(t)
assert := With(t)
buffer := New()
defer buffer.Release()
assert.Error(buffer.AppendSupplier(serial.WriteString("Test String"))).IsNil()
assert.String(buffer.String()).Equals("Test String")
assert(buffer.AppendSupplier(serial.WriteString("Test String")), IsNil)
assert(buffer.String(), Equals, "Test String")
}
func TestBufferWrite(t *testing.T) {
assert := assert.On(t)
assert := With(t)
buffer := NewLocal(8)
nBytes, err := buffer.Write([]byte("abcd"))
assert.Error(err).IsNil()
assert.Int(nBytes).Equals(4)
assert(err, IsNil)
assert(nBytes, Equals, 4)
nBytes, err = buffer.Write([]byte("abcde"))
assert.Error(err).IsNil()
assert.Int(nBytes).Equals(4)
assert.String(buffer.String()).Equals("abcdabcd")
assert(err, IsNil)
assert(nBytes, Equals, 4)
assert(buffer.String(), Equals, "abcdabcd")
}
func TestSyncPool(t *testing.T) {
assert := assert.On(t)
assert := With(t)
p := NewSyncPool(32)
b := p.Allocate()
assert.Int(b.Len()).Equals(0)
assert(b.Len(), Equals, 0)
assert.Error(b.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
assert.Int(b.Len()).Equals(32)
assert(b.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
assert(b.Len(), Equals, 32)
b.Release()
}

View File

@ -1,53 +0,0 @@
package buf
import (
"io"
)
// BufferedReader is a reader with internal cache.
type BufferedReader struct {
reader io.Reader
buffer *Buffer
buffered bool
}
// NewBufferedReader creates a new BufferedReader based on an io.Reader.
func NewBufferedReader(rawReader io.Reader) *BufferedReader {
return &BufferedReader{
reader: rawReader,
buffer: NewLocal(1024),
buffered: true,
}
}
// IsBuffered returns true if the internal cache is effective.
func (v *BufferedReader) IsBuffered() bool {
return v.buffered
}
// SetBuffered is to enable or disable internal cache. If cache is disabled,
// Read() calls will be delegated to the underlying io.Reader directly.
func (v *BufferedReader) SetBuffered(cached bool) {
v.buffered = cached
}
// Read implements io.Reader.Read().
func (v *BufferedReader) Read(b []byte) (int, error) {
if !v.buffered || v.buffer == nil {
if !v.buffer.IsEmpty() {
return v.buffer.Read(b)
}
return v.reader.Read(b)
}
if v.buffer.IsEmpty() {
if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
return 0, err
}
}
if v.buffer.IsEmpty() {
return 0, nil
}
return v.buffer.Read(b)
}

View File

@ -1,36 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
)
func TestBufferedReader(t *testing.T) {
assert := assert.On(t)
content := New()
assert.Error(content.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
len := content.Len()
reader := NewBufferedReader(content)
assert.Bool(reader.IsBuffered()).IsTrue()
payload := make([]byte, 16)
nBytes, err := reader.Read(payload)
assert.Int(nBytes).Equals(16)
assert.Error(err).IsNil()
len2 := content.Len()
assert.Int(len - len2).GreaterThan(16)
nBytes, err = reader.Read(payload)
assert.Int(nBytes).Equals(16)
assert.Error(err).IsNil()
assert.Int(content.Len()).Equals(len2)
}

View File

@ -1,72 +0,0 @@
package buf
import "io"
// BufferedWriter is an io.Writer with internal buffer. It writes to underlying writer when buffer is full or on demand.
// This type is not thread safe.
type BufferedWriter struct {
writer io.Writer
buffer *Buffer
buffered bool
}
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer io.Writer) *BufferedWriter {
return NewBufferedWriterSize(writer, 1024)
}
func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: NewLocal(int(size)),
buffered: true,
}
}
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered || w.buffer == nil {
return w.writer.Write(b)
}
bytesWritten := 0
for bytesWritten < len(b) {
nBytes, err := w.buffer.Write(b[bytesWritten:])
if err != nil {
return bytesWritten, err
}
bytesWritten += nBytes
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return bytesWritten, err
}
}
}
return bytesWritten, nil
}
// Flush writes all buffered content into underlying writer, if any.
func (w *BufferedWriter) Flush() error {
defer w.buffer.Clear()
for !w.buffer.IsEmpty() {
nBytes, err := w.writer.Write(w.buffer.Bytes())
if err != nil {
return err
}
w.buffer.SliceFrom(nBytes)
}
return nil
}
// IsBuffered returns true if this BufferedWriter holds a buffer.
func (w *BufferedWriter) IsBuffered() bool {
return w.buffered
}
// SetBuffered controls whether the BufferedWriter holds a buffer for writing. If not buffered, any write() calls into underlying writer directly.
func (w *BufferedWriter) SetBuffered(cached bool) error {
w.buffered = cached
if !cached && !w.buffer.IsEmpty() {
return w.Flush()
}
return nil
}

View File

@ -1,53 +0,0 @@
package buf_test
import (
"crypto/rand"
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
)
func TestBufferedWriter(t *testing.T) {
assert := assert.On(t)
content := New()
writer := NewBufferedWriter(content)
assert.Bool(writer.IsBuffered()).IsTrue()
payload := make([]byte, 16)
nBytes, err := writer.Write(payload)
assert.Int(nBytes).Equals(16)
assert.Error(err).IsNil()
assert.Bool(content.IsEmpty()).IsTrue()
assert.Error(writer.SetBuffered(false)).IsNil()
assert.Int(content.Len()).Equals(16)
}
func TestBufferedWriterLargePayload(t *testing.T) {
assert := assert.On(t)
content := NewLocal(128 * 1024)
writer := NewBufferedWriter(content)
assert.Bool(writer.IsBuffered()).IsTrue()
payload := make([]byte, 64*1024)
rand.Read(payload)
nBytes, err := writer.Write(payload[:512])
assert.Int(nBytes).Equals(512)
assert.Error(err).IsNil()
assert.Bool(content.IsEmpty()).IsTrue()
nBytes, err = writer.Write(payload[512:])
assert.Error(err).IsNil()
assert.Error(writer.Flush()).IsNil()
assert.Int(nBytes).Equals(64*1024 - 512)
assert.Bytes(content.Bytes()).Equals(payload)
}

View File

@ -17,7 +17,7 @@ type copyHandler struct {
}
func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
mb, err := reader.Read()
mb, err := reader.ReadMultiBuffer()
if err != nil {
for _, handler := range h.onReadError {
err = handler(err)
@ -27,7 +27,7 @@ func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
}
func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
err := writer.Write(mb)
err := writer.WriteMultiBuffer(mb)
if err != nil {
for _, handler := range h.onWriteError {
err = handler(err)
@ -36,8 +36,14 @@ func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
return err
}
type SizeCounter struct {
Size int64
}
// CopyOption is an option for copying data.
type CopyOption func(*copyHandler)
// IgnoreReaderError is a CopyOption that ignores errors from reader. Copy will continue in such case.
func IgnoreReaderError() CopyOption {
return func(handler *copyHandler) {
handler.onReadError = append(handler.onReadError, func(err error) error {
@ -46,6 +52,7 @@ func IgnoreReaderError() CopyOption {
}
}
// IgnoreWriterError is a CopyOption that ignores errors from writer. Copy will continue in such case.
func IgnoreWriterError() CopyOption {
return func(handler *copyHandler) {
handler.onWriteError = append(handler.onWriteError, func(err error) error {
@ -54,6 +61,7 @@ func IgnoreWriterError() CopyOption {
}
}
// UpdateActivity is a CopyOption to update activity on each data copy operation.
func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(MultiBuffer) {
@ -62,31 +70,34 @@ func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
}
}
// CountSize is a CopyOption that sums the total size of data copied into the given SizeCounter.
func CountSize(sc *SizeCounter) CopyOption {
return func(handler *copyHandler) {
handler.onData = append(handler.onData, func(b MultiBuffer) {
sc.Size += int64(b.Len())
})
}
}
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
for {
buffer, err := handler.readFrom(reader)
if err != nil {
return err
}
if !buffer.IsEmpty() {
for _, handler := range handler.onData {
handler(buffer)
}
if buffer.IsEmpty() {
buffer.Release()
continue
}
for _, handler := range handler.onData {
handler(buffer)
}
if err := handler.writeTo(writer, buffer); err != nil {
buffer.Release()
if werr := handler.writeTo(writer, buffer); werr != nil {
buffer.Release()
return werr
}
} else if err != nil {
return err
}
}
}
// Copy dumps all payload from reader to writer or stops when an error occurs.
// ActivityTimer gets updated as soon as there is a payload.
// Copy dumps all payload from reader to writer or stops when an error occurs. It returns nil when EOF.
func Copy(reader Reader, writer Writer, options ...CopyOption) error {
handler := new(copyHandler)
for _, option := range options {

View File

@ -5,22 +5,24 @@ import (
"time"
)
// Reader extends io.Reader with alloc.Buffer.
// Reader extends io.Reader with MultiBuffer.
type Reader interface {
// Read reads content from underlying reader, and put it into an alloc.Buffer.
Read() (MultiBuffer, error)
// ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
ReadMultiBuffer() (MultiBuffer, error)
}
// ErrReadTimeout is an error that happens with IO timeout.
var ErrReadTimeout = newError("IO timeout")
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
type TimeoutReader interface {
ReadTimeout(time.Duration) (MultiBuffer, error)
}
// Writer extends io.Writer with alloc.Buffer.
// Writer extends io.Writer with MultiBuffer.
type Writer interface {
// Write writes an alloc.Buffer into underlying writer.
Write(MultiBuffer) error
// WriteMultiBuffer writes a MultiBuffer into underlying writer.
WriteMultiBuffer(MultiBuffer) error
}
// ReadFrom creates a Supplier to read from a given io.Reader.
@ -47,57 +49,21 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
// NewReader creates a new Reader.
// The Reader instance doesn't take the ownership of reader.
func NewReader(reader io.Reader) Reader {
if mr, ok := reader.(MultiBufferReader); ok {
return &readerAdpater{
MultiBufferReader: mr,
}
if mr, ok := reader.(Reader); ok {
return mr
}
return &BytesToBufferReader{
reader: reader,
buffer: make([]byte, 32*1024),
}
}
func NewMergingReader(reader io.Reader) Reader {
return NewMergingReaderSize(reader, 32*1024)
}
func NewMergingReaderSize(reader io.Reader, size uint32) Reader {
return &BytesToBufferReader{
reader: reader,
buffer: make([]byte, size),
}
}
// ToBytesReader converts a Reaaer to io.Reader.
func ToBytesReader(stream Reader) io.Reader {
return &bufferToBytesReader{
stream: stream,
}
return NewBytesToBufferReader(reader)
}
// NewWriter creates a new Writer.
func NewWriter(writer io.Writer) Writer {
if mw, ok := writer.(MultiBufferWriter); ok {
return &writerAdapter{
writer: mw,
}
if mw, ok := writer.(Writer); ok {
return mw
}
return &BufferToBytesWriter{
writer: writer,
}
}
func NewMergingWriter(writer io.Writer) Writer {
return NewMergingWriterSize(writer, 4096)
}
func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
return &mergingWriter{
writer: writer,
buffer: make([]byte, size),
Writer: writer,
}
}
@ -106,10 +72,3 @@ func NewSequentialWriter(writer io.Writer) Writer {
writer: writer,
}
}
// ToBytesWriter converts a Writer to io.Writer
func ToBytesWriter(writer Writer) io.Writer {
return &bytesToBufferWriter{
writer: writer,
}
}

View File

@ -1,21 +1,53 @@
package buf
import "net"
import (
"io"
"net"
type MultiBufferWriter interface {
WriteMultiBuffer(MultiBuffer) error
"v2ray.com/core/common"
"v2ray.com/core/common/errors"
)
// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF.
func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
mb := NewMultiBufferCap(128)
for {
b := New()
err := b.Reset(ReadFrom(reader))
if b.IsEmpty() {
b.Release()
} else {
mb.Append(b)
}
if err != nil {
if errors.Cause(err) == io.EOF {
return mb, nil
}
mb.Release()
return nil, err
}
}
}
type MultiBufferReader interface {
ReadMultiBuffer() (MultiBuffer, error)
// ReadAllToBytes reads all content from the reader into a byte array, until EOF.
func ReadAllToBytes(reader io.Reader) ([]byte, error) {
mb, err := ReadAllToMultiBuffer(reader)
if err != nil {
return nil, err
}
b := make([]byte, mb.Len())
common.Must2(mb.Read(b))
mb.Release()
return b, nil
}
// MultiBuffer is a list of Buffers. The order of Buffer matters.
type MultiBuffer []*Buffer
// NewMultiBuffer creates a new MultiBuffer instance.
func NewMultiBuffer() MultiBuffer {
return MultiBuffer(make([]*Buffer, 0, 128))
// NewMultiBufferCap creates a new MultiBuffer instance.
func NewMultiBufferCap(capacity int) MultiBuffer {
return MultiBuffer(make([]*Buffer, 0, capacity))
}
// NewMultiBufferValue wraps a list of Buffers into MultiBuffer.
@ -23,14 +55,17 @@ func NewMultiBufferValue(b ...*Buffer) MultiBuffer {
return MultiBuffer(b)
}
// Append appends buffer to the end of this MultiBuffer
func (mb *MultiBuffer) Append(buf *Buffer) {
*mb = append(*mb, buf)
}
// AppendMulti appends a MultiBuffer to the end of this one.
func (mb *MultiBuffer) AppendMulti(buf MultiBuffer) {
*mb = append(*mb, buf...)
}
// Copy copied the begining part of the MultiBuffer into the given byte array.
func (mb MultiBuffer) Copy(b []byte) int {
total := 0
for _, bb := range mb {
@ -43,6 +78,7 @@ func (mb MultiBuffer) Copy(b []byte) int {
return total
}
// Read implements io.Reader.
func (mb *MultiBuffer) Read(b []byte) (int, error) {
endIndex := len(*mb)
totalBytes := 0
@ -52,6 +88,7 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
b = b[nBytes:]
if bb.IsEmpty() {
bb.Release()
(*mb)[i] = nil
} else {
endIndex = i
break
@ -61,6 +98,7 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
return totalBytes, nil
}
// Write implements io.Writer.
func (mb *MultiBuffer) Write(b []byte) {
n := len(*mb)
if n > 0 && !(*mb)[n-1].IsFull() {
@ -96,11 +134,12 @@ func (mb MultiBuffer) IsEmpty() bool {
}
// Release releases all Buffers in the MultiBuffer.
func (mb MultiBuffer) Release() {
for i, b := range mb {
func (mb *MultiBuffer) Release() {
for i, b := range *mb {
b.Release()
mb[i] = nil
(*mb)[i] = nil
}
*mb = nil
}
// ToNetBuffers converts this MultiBuffer to net.Buffers. The return net.Buffers points to the same content of the MultiBuffer.
@ -112,8 +151,9 @@ func (mb MultiBuffer) ToNetBuffers() net.Buffers {
return bs
}
// SliceBySize splits the begining of this MultiBuffer into another one, for at most size bytes.
func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
slice := NewMultiBuffer()
slice := NewMultiBufferCap(10)
sliceSize := 0
endIndex := len(*mb)
for i, b := range *mb {
@ -123,16 +163,24 @@ func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
}
sliceSize += b.Len()
slice.Append(b)
(*mb)[i] = nil
}
*mb = (*mb)[endIndex:]
if endIndex == 0 && len(*mb) > 0 {
b := New()
common.Must(b.Reset(ReadFullFrom((*mb)[0], size)))
return NewMultiBufferValue(b)
}
return slice
}
// SplitFirst splits out the first Buffer in this MultiBuffer.
func (mb *MultiBuffer) SplitFirst() *Buffer {
if len(*mb) == 0 {
return nil
}
b := (*mb)[0]
(*mb)[0] = nil
*mb = (*mb)[1:]
return b
}

View File

@ -4,11 +4,11 @@ import (
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestMultiBufferRead(t *testing.T) {
assert := assert.On(t)
assert := With(t)
b1 := New()
b1.AppendBytes('a', 'b')
@ -19,17 +19,17 @@ func TestMultiBufferRead(t *testing.T) {
bs := make([]byte, 32)
nBytes, err := mb.Read(bs)
assert.Error(err).IsNil()
assert.Int(nBytes).Equals(4)
assert.Bytes(bs[:nBytes]).Equals([]byte("abcd"))
assert(err, IsNil)
assert(nBytes, Equals, 4)
assert(bs[:nBytes], Equals, []byte("abcd"))
}
func TestMultiBufferAppend(t *testing.T) {
assert := assert.On(t)
assert := With(t)
var mb MultiBuffer
b := New()
b.AppendBytes('a', 'b')
mb.Append(b)
assert.Int(mb.Len()).Equals(2)
assert(mb.Len(), Equals, 2)
}

View File

@ -6,38 +6,86 @@ import (
"v2ray.com/core/common/errors"
)
var (
_ Reader = (*BytesToBufferReader)(nil)
_ io.Reader = (*BytesToBufferReader)(nil)
)
// BytesToBufferReader is a Reader that adjusts its reading speed automatically.
type BytesToBufferReader struct {
reader io.Reader
io.Reader
buffer []byte
}
// Read implements Reader.Read().
func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
nBytes, err := r.reader.Read(r.buffer)
if err != nil {
return nil, err
func NewBytesToBufferReader(reader io.Reader) Reader {
return &BytesToBufferReader{
Reader: reader,
}
}
func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
b := New()
err := b.Reset(ReadFrom(r.Reader))
if b.IsFull() {
r.buffer = make([]byte, 32*1024)
}
if !b.IsEmpty() {
return NewMultiBufferValue(b), nil
}
b.Release()
return nil, err
}
// ReadMultiBuffer implements Reader.
func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.buffer == nil {
return r.readSmall()
}
mb := NewMultiBuffer()
mb.Write(r.buffer[:nBytes])
return mb, nil
nBytes, err := r.Reader.Read(r.buffer)
if nBytes > 0 {
mb := NewMultiBufferCap(nBytes/Size + 1)
mb.Write(r.buffer[:nBytes])
return mb, nil
}
return nil, err
}
type readerAdpater struct {
MultiBufferReader
}
var (
_ Reader = (*BufferedReader)(nil)
_ io.Reader = (*BufferedReader)(nil)
_ io.ByteReader = (*BufferedReader)(nil)
_ io.WriterTo = (*BufferedReader)(nil)
)
func (r *readerAdpater) Read() (MultiBuffer, error) {
return r.ReadMultiBuffer()
}
type bufferToBytesReader struct {
type BufferedReader struct {
stream Reader
leftOver MultiBuffer
buffered bool
}
func (r *bufferToBytesReader) Read(b []byte) (int, error) {
func NewBufferedReader(reader Reader) *BufferedReader {
return &BufferedReader{
stream: reader,
buffered: true,
}
}
func (r *BufferedReader) SetBuffered(f bool) {
r.buffered = f
}
func (r *BufferedReader) IsBuffered() bool {
return r.buffered
}
func (r *BufferedReader) ReadByte() (byte, error) {
var b [1]byte
_, err := r.Read(b[:])
return b[0], err
}
func (r *BufferedReader) Read(b []byte) (int, error) {
if r.leftOver != nil {
nBytes, _ := r.leftOver.Read(b)
if r.leftOver.IsEmpty() {
@ -47,51 +95,75 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
return nBytes, nil
}
mb, err := r.stream.Read()
if err != nil {
return 0, err
if !r.buffered {
if reader, ok := r.stream.(io.Reader); ok {
return reader.Read(b)
}
}
nBytes, _ := mb.Read(b)
if !mb.IsEmpty() {
r.leftOver = mb
mb, err := r.stream.ReadMultiBuffer()
if mb != nil {
nBytes, _ := mb.Read(b)
if !mb.IsEmpty() {
r.leftOver = mb
}
return nBytes, err
}
return nBytes, nil
return 0, err
}
func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) {
func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.leftOver != nil {
mb := r.leftOver
r.leftOver = nil
return mb, nil
}
return r.stream.Read()
return r.stream.ReadMultiBuffer()
}
func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) {
// ReadAtMost returns a MultiBuffer with at most size.
func (r *BufferedReader) ReadAtMost(size int) (MultiBuffer, error) {
if r.leftOver == nil {
mb, err := r.stream.ReadMultiBuffer()
if mb.IsEmpty() && err != nil {
return nil, err
}
r.leftOver = mb
}
mb := r.leftOver.SliceBySize(size)
if r.leftOver.IsEmpty() {
r.leftOver = nil
}
return mb, nil
}
func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
mbWriter := NewWriter(writer)
totalBytes := int64(0)
if r.leftOver != nil {
if err := mbWriter.Write(r.leftOver); err != nil {
totalBytes += int64(r.leftOver.Len())
if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil {
return 0, err
}
totalBytes += int64(r.leftOver.Len())
}
for {
mb, err := r.stream.Read()
if err != nil {
return totalBytes, err
mb, err := r.stream.ReadMultiBuffer()
if mb != nil {
totalBytes += int64(mb.Len())
if werr := mbWriter.WriteMultiBuffer(mb); werr != nil {
return totalBytes, err
}
}
totalBytes += int64(mb.Len())
if err := mbWriter.Write(mb); err != nil {
if err != nil {
return totalBytes, err
}
}
}
func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) {
func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
nBytes, err := r.writeToInternal(writer)
if errors.Cause(err) == io.EOF {
return nBytes, nil

View File

@ -7,64 +7,66 @@ import (
"testing"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray"
. "v2ray.com/ext/assert"
)
func TestAdaptiveReader(t *testing.T) {
assert := assert.On(t)
assert := With(t)
rawContent := make([]byte, 1024*1024)
buffer := bytes.NewBuffer(rawContent)
reader := NewReader(bytes.NewReader(make([]byte, 1024*1024)))
b, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(b.Len(), Equals, 2*1024)
reader := NewReader(buffer)
b, err := reader.Read()
assert.Error(err).IsNil()
assert.Int(b.Len()).Equals(32 * 1024)
b, err = reader.ReadMultiBuffer()
assert(err, IsNil)
assert(b.Len(), Equals, 32*1024)
}
func TestBytesReaderWriteTo(t *testing.T) {
assert := assert.On(t)
assert := With(t)
stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream)
reader := NewBufferedReader(stream)
b1 := New()
b1.AppendBytes('a', 'b', 'c')
b2 := New()
b2.AppendBytes('e', 'f', 'g')
assert.Error(stream.Write(NewMultiBufferValue(b1, b2))).IsNil()
assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close()
stream2 := ray.NewStream(context.Background())
writer := ToBytesWriter(stream2)
writer := NewBufferedWriter(stream2)
writer.SetBuffered(false)
nBytes, err := io.Copy(writer, reader)
assert.Error(err).IsNil()
assert.Int64(nBytes).Equals(6)
assert(err, IsNil)
assert(nBytes, Equals, int64(6))
mb, err := stream2.Read()
assert.Error(err).IsNil()
assert.Int(len(mb)).Equals(2)
assert.String(mb[0].String()).Equals("abc")
assert.String(mb[1].String()).Equals("efg")
mb, err := stream2.ReadMultiBuffer()
assert(err, IsNil)
assert(len(mb), Equals, 2)
assert(mb[0].String(), Equals, "abc")
assert(mb[1].String(), Equals, "efg")
}
func TestBytesReaderMultiBuffer(t *testing.T) {
assert := assert.On(t)
assert := With(t)
stream := ray.NewStream(context.Background())
reader := ToBytesReader(stream)
reader := NewBufferedReader(stream)
b1 := New()
b1.AppendBytes('a', 'b', 'c')
b2 := New()
b2.AppendBytes('e', 'f', 'g')
assert.Error(stream.Write(NewMultiBufferValue(b1, b2))).IsNil()
assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
stream.Close()
mbReader := NewReader(reader)
mb, err := mbReader.Read()
assert.Error(err).IsNil()
assert.Int(len(mb)).Equals(2)
assert.String(mb[0].String()).Equals("abc")
assert.String(mb[1].String()).Equals("efg")
mb, err := mbReader.ReadMultiBuffer()
assert(err, IsNil)
assert(len(mb), Equals, 2)
assert(mb[0].String(), Equals, "abc")
assert(mb[1].String(), Equals, "efg")
}

View File

@ -1,52 +1,164 @@
package buf
import "io"
import (
"io"
"v2ray.com/core/common/errors"
)
var (
_ io.ReaderFrom = (*BufferToBytesWriter)(nil)
_ io.Writer = (*BufferToBytesWriter)(nil)
_ Writer = (*BufferToBytesWriter)(nil)
)
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
type BufferToBytesWriter struct {
writer io.Writer
io.Writer
}
// Write implements Writer.Write(). Write() takes ownership of the given buffer.
func (w *BufferToBytesWriter) Write(mb MultiBuffer) error {
func NewBufferToBytesWriter(writer io.Writer) *BufferToBytesWriter {
return &BufferToBytesWriter{
Writer: writer,
}
}
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release()
bs := mb.ToNetBuffers()
_, err := bs.WriteTo(w.writer)
_, err := bs.WriteTo(w)
return err
}
type writerAdapter struct {
writer MultiBufferWriter
// ReadFrom implements io.ReaderFrom.
func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
// Write implements buf.MultiBufferWriter.
func (w *writerAdapter) Write(mb MultiBuffer) error {
return w.writer.WriteMultiBuffer(mb)
var (
_ io.ReaderFrom = (*BufferedWriter)(nil)
_ io.Writer = (*BufferedWriter)(nil)
_ Writer = (*BufferedWriter)(nil)
_ io.ByteWriter = (*BufferedWriter)(nil)
)
// BufferedWriter is a Writer with internal buffer.
type BufferedWriter struct {
writer Writer
buffer *Buffer
buffered bool
}
type mergingWriter struct {
writer io.Writer
buffer []byte
// NewBufferedWriter creates a new BufferedWriter.
func NewBufferedWriter(writer Writer) *BufferedWriter {
return &BufferedWriter{
writer: writer,
buffer: New(),
buffered: true,
}
}
func (w *mergingWriter) Write(mb MultiBuffer) error {
defer mb.Release()
func (w *BufferedWriter) WriteByte(c byte) error {
_, err := w.Write([]byte{c})
return err
}
for !mb.IsEmpty() {
nBytes, _ := mb.Read(w.buffer)
if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil {
// Write implements io.Writer.
func (w *BufferedWriter) Write(b []byte) (int, error) {
if !w.buffered {
if writer, ok := w.writer.(io.Writer); ok {
return writer.Write(b)
}
}
totalBytes := 0
for len(b) > 0 {
if w.buffer == nil {
w.buffer = New()
}
nBytes, err := w.buffer.Write(b)
totalBytes += nBytes
if err != nil {
return totalBytes, err
}
if !w.buffered || w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return totalBytes, err
}
}
b = b[nBytes:]
}
return totalBytes, nil
}
// WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
if !w.buffered {
return w.writer.WriteMultiBuffer(b)
}
defer b.Release()
for !b.IsEmpty() {
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
return err
}
if w.buffer.IsFull() {
if err := w.Flush(); err != nil {
return err
}
}
}
return nil
}
// Flush flushes buffered content into underlying writer.
func (w *BufferedWriter) Flush() error {
if !w.buffer.IsEmpty() {
if err := w.writer.WriteMultiBuffer(NewMultiBufferValue(w.buffer)); err != nil {
return err
}
if w.buffered {
w.buffer = New()
} else {
w.buffer = nil
}
}
return nil
}
func (w *BufferedWriter) SetBuffered(f bool) error {
w.buffered = f
if !f {
return w.Flush()
}
return nil
}
// ReadFrom implements io.ReaderFrom.
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
if err := w.SetBuffered(false); err != nil {
return 0, err
}
var sc SizeCounter
err := Copy(NewReader(reader), w, CountSize(&sc))
return sc.Size, err
}
type seqWriter struct {
writer io.Writer
}
func (w *seqWriter) Write(mb MultiBuffer) error {
func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
defer mb.Release()
for _, b := range mb {
@ -61,54 +173,38 @@ func (w *seqWriter) Write(mb MultiBuffer) error {
return nil
}
var (
_ MultiBufferWriter = (*bytesToBufferWriter)(nil)
)
type bytesToBufferWriter struct {
writer Writer
}
// Write implements io.Writer.
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
mb := NewMultiBuffer()
mb.Write(payload)
if err := w.writer.Write(mb); err != nil {
return 0, err
}
return len(payload), nil
}
func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) error {
return w.writer.Write(mb)
}
func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) {
mbReader := NewReader(reader)
totalBytes := int64(0)
eof := false
for !eof {
mb, err := mbReader.Read()
if err == io.EOF {
eof = true
} else if err != nil {
return totalBytes, err
}
totalBytes += int64(mb.Len())
if err := w.writer.Write(mb); err != nil {
return totalBytes, err
}
}
return totalBytes, nil
}
type noOpWriter struct{}
func (noOpWriter) Write(b MultiBuffer) error {
func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
b.Release()
return nil
}
func (noOpWriter) Write(b []byte) (int, error) {
return len(b), nil
}
func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
b := New()
defer b.Release()
totalBytes := int64(0)
for {
err := b.Reset(ReadFrom(reader))
totalBytes += int64(b.Len())
if err != nil {
if errors.Cause(err) == io.EOF {
return totalBytes, nil
}
return totalBytes, err
}
}
}
var (
// Discard is a Writer that swallows all contents written in.
Discard Writer = noOpWriter{}
// DiscardBytes is an io.Writer that swallows all contents written in.
DiscardBytes io.Writer = noOpWriter{}
)

View File

@ -9,37 +9,67 @@ import (
"context"
"io"
"v2ray.com/core/common"
. "v2ray.com/core/common/buf"
"v2ray.com/core/testing/assert"
"v2ray.com/core/transport/ray"
. "v2ray.com/ext/assert"
)
func TestWriter(t *testing.T) {
assert := assert.On(t)
assert := With(t)
lb := New()
assert.Error(lb.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
expectedBytes := append([]byte(nil), lb.Bytes()...)
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
writer := NewWriter(NewBufferedWriter(writeBuffer))
err := writer.Write(NewMultiBufferValue(lb))
assert.Error(err).IsNil()
assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes())
writer := NewBufferedWriter(NewWriter(writeBuffer))
writer.SetBuffered(false)
err := writer.WriteMultiBuffer(NewMultiBufferValue(lb))
assert(err, IsNil)
assert(writer.Flush(), IsNil)
assert(expectedBytes, Equals, writeBuffer.Bytes())
}
func TestBytesWriterReadFrom(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cache := ray.NewStream(context.Background())
reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192))
_, err := reader.WriteTo(ToBytesWriter(cache))
assert.Error(err).IsNil()
const size = 50000
reader := bufio.NewReader(io.LimitReader(rand.Reader, size))
writer := NewBufferedWriter(cache)
writer.SetBuffered(false)
nBytes, err := reader.WriteTo(writer)
assert(nBytes, Equals, int64(size))
assert(err, IsNil)
mb, err := cache.Read()
assert.Error(err).IsNil()
assert.Int(mb.Len()).Equals(8192)
assert.Int(len(mb)).Equals(4)
mb, err := cache.ReadMultiBuffer()
assert(err, IsNil)
assert(mb.Len(), Equals, size)
}
func TestDiscardBytes(t *testing.T) {
assert := With(t)
b := New()
common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
nBytes, err := io.Copy(DiscardBytes, b)
assert(nBytes, Equals, int64(Size))
assert(err, IsNil)
}
func TestDiscardBytesMultiBuffer(t *testing.T) {
assert := With(t)
const size = 10240*1024 + 1
buffer := bytes.NewBuffer(make([]byte, 0, size))
common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
r := NewReader(buffer)
nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r))
assert(nBytes, Equals, int64(size))
assert(err, IsNil)
}

View File

@ -11,6 +11,7 @@ func Must(err error) {
}
}
// Must2 panics if the second parameter is not nil.
func Must2(v interface{}, err error) {
if err != nil {
panic(err)

View File

@ -29,6 +29,26 @@ func (v StaticBytesGenerator) Next() []byte {
return v.Content
}
type IncreasingAEADNonceGenerator struct {
nonce []byte
}
func NewIncreasingAEADNonceGenerator() *IncreasingAEADNonceGenerator {
return &IncreasingAEADNonceGenerator{
nonce: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
}
}
func (g *IncreasingAEADNonceGenerator) Next() []byte {
for i := range g.nonce {
g.nonce[i]++
if g.nonce[i] != 0 {
break
}
}
return g.nonce
}
type Authenticator interface {
NonceSize() int
Overhead() int
@ -48,7 +68,10 @@ func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
return nil, newError("invalid AEAD nonce size: ", len(iv))
}
additionalData := v.AdditionalDataGenerator.Next()
var additionalData []byte
if v.AdditionalDataGenerator != nil {
additionalData = v.AdditionalDataGenerator.Next()
}
return v.AEAD.Open(dst, iv, cipherText, additionalData)
}
@ -58,7 +81,10 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
return nil, newError("invalid AEAD nonce size: ", len(iv))
}
additionalData := v.AdditionalDataGenerator.Next()
var additionalData []byte
if v.AdditionalDataGenerator != nil {
additionalData = v.AdditionalDataGenerator.Next()
}
return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
}
@ -93,7 +119,12 @@ func (r *AuthenticationReader) readSize() error {
sizeBytes := r.sizeParser.SizeBytes()
if r.buffer.Len() < sizeBytes {
r.buffer.Reset(buf.ReadFrom(r.buffer))
if r.buffer.IsEmpty() {
r.buffer.Clear()
} else {
common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
}
delta := sizeBytes - r.buffer.Len()
if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
return err
@ -146,18 +177,18 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
return b, nil
}
func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
b, err := r.readChunk(true)
if err != nil {
return nil, err
}
mb := buf.NewMultiBuffer()
var mb buf.MultiBuffer
if r.transferType == protocol.TransferTypeStream {
mb.Write(b)
} else {
var bb *buf.Buffer
if len(b) < buf.Size {
if len(b) <= buf.Size {
bb = buf.New()
} else {
bb = buf.NewLocal(len(b))
@ -175,7 +206,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
mb.Write(b)
} else {
var bb *buf.Buffer
if len(b) < buf.Size {
if len(b) <= buf.Size {
bb = buf.New()
} else {
bb = buf.NewLocal(len(b))
@ -190,79 +221,92 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
type AuthenticationWriter struct {
auth Authenticator
buffer []byte
payload []byte
writer *buf.BufferedWriter
writer buf.Writer
sizeParser ChunkSizeEncoder
transferType protocol.TransferType
}
func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
const payloadSize = 1024
return &AuthenticationWriter{
auth: auth,
buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()),
payload: make([]byte, payloadSize),
writer: buf.NewBufferedWriterSize(writer, readerBufferSize),
writer: buf.NewWriter(writer),
sizeParser: sizeParser,
transferType: transferType,
}
}
func (w *AuthenticationWriter) append(b []byte) error {
encryptedSize := len(b) + w.auth.Overhead()
buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0])
func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
encryptedSize := b.Len() + w.auth.Overhead()
buffer, err := w.auth.Seal(buffer, b)
if err != nil {
return err
eb := buf.New()
common.Must(eb.Reset(func(bb []byte) (int, error) {
w.sizeParser.Encode(uint16(encryptedSize), bb[:0])
return w.sizeParser.SizeBytes(), nil
}))
if err := eb.AppendSupplier(func(bb []byte) (int, error) {
_, err := w.auth.Seal(bb[:0], b.Bytes())
return encryptedSize, err
}); err != nil {
eb.Release()
return nil, err
}
if _, err := w.writer.Write(buffer); err != nil {
return err
}
return nil
return eb, nil
}
func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
defer mb.Release()
payloadSize := buf.Size - w.auth.Overhead() - w.sizeParser.SizeBytes()
mb2Write := buf.NewMultiBufferCap(len(mb) + 10)
for {
n, _ := mb.Read(w.payload)
if err := w.append(w.payload[:n]); err != nil {
b := buf.New()
common.Must(b.Reset(func(bb []byte) (int, error) {
return mb.Read(bb[:payloadSize])
}))
eb, err := w.seal(b)
b.Release()
if err != nil {
mb2Write.Release()
return err
}
mb2Write.Append(eb)
if mb.IsEmpty() {
break
}
}
return w.writer.Flush()
return w.writer.WriteMultiBuffer(mb2Write)
}
func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
defer mb.Release()
mb2Write := buf.NewMultiBufferCap(len(mb) * 2)
for {
b := mb.SplitFirst()
if b == nil {
b = buf.New()
}
if err := w.append(b.Bytes()); err != nil {
b.Release()
eb, err := w.seal(b)
b.Release()
if err != nil {
mb2Write.Release()
return err
}
b.Release()
mb2Write.Append(eb)
if mb.IsEmpty() {
break
}
}
return w.writer.Flush()
return w.writer.WriteMultiBuffer(mb2Write)
}
func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
if w.transferType == protocol.TransferTypeStream {
return w.writeStream(mb)
}

View File

@ -10,19 +10,19 @@ import (
"v2ray.com/core/common/buf"
. "v2ray.com/core/common/crypto"
"v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestAuthenticationReaderWriter(t *testing.T) {
assert := assert.On(t)
assert := With(t)
key := make([]byte, 16)
rand.Read(key)
block, err := aes.NewCipher(key)
assert.Error(err).IsNil()
assert(err, IsNil)
aead, err := cipher.NewGCM(block)
assert.Error(err).IsNil()
assert(err, IsNil)
rawPayload := make([]byte, 8192*10)
rand.Read(rawPayload)
@ -42,10 +42,10 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
assert.Int(cache.Len()).Equals(83360)
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
assert.Error(err).IsNil()
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
assert(cache.Len(), Equals, 82658)
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead,
@ -55,33 +55,33 @@ func TestAuthenticationReaderWriter(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
mb := buf.NewMultiBuffer()
var mb buf.MultiBuffer
for mb.Len() < len(rawPayload) {
mb2, err := reader.Read()
assert.Error(err).IsNil()
mb2, err := reader.ReadMultiBuffer()
assert(err, IsNil)
mb.AppendMulti(mb2)
}
mbContent := make([]byte, 8192*10)
mb.Read(mbContent)
assert.Bytes(mbContent).Equals(rawPayload)
assert(mbContent, Equals, rawPayload)
_, err = reader.Read()
assert.Error(err).Equals(io.EOF)
_, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}
func TestAuthenticationReaderWriterPacket(t *testing.T) {
assert := assert.On(t)
assert := With(t)
key := make([]byte, 16)
rand.Read(key)
block, err := aes.NewCipher(key)
assert.Error(err).IsNil()
assert(err, IsNil)
aead, err := cipher.NewGCM(block)
assert.Error(err).IsNil()
assert(err, IsNil)
cache := buf.NewLocal(1024)
iv := make([]byte, 12)
@ -95,7 +95,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
payload := buf.NewMultiBuffer()
var payload buf.MultiBuffer
pb1 := buf.New()
pb1.Append([]byte("abcd"))
payload.Append(pb1)
@ -104,10 +104,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
pb2.Append([]byte("efgh"))
payload.Append(pb2)
assert.Error(writer.Write(payload)).IsNil()
assert.Int(cache.Len()).GreaterThan(0)
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
assert.Error(err).IsNil()
assert(writer.WriteMultiBuffer(payload), IsNil)
assert(cache.Len(), GreaterThan, 0)
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert(err, IsNil)
reader := NewAuthenticationReader(&AEADAuthenticator{
AEAD: aead,
@ -117,15 +117,15 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
AdditionalDataGenerator: &NoOpBytesGenerator{},
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
mb, err := reader.Read()
assert.Error(err).IsNil()
mb, err := reader.ReadMultiBuffer()
assert(err, IsNil)
b1 := mb.SplitFirst()
assert.String(b1.String()).Equals("abcd")
assert(b1.String(), Equals, "abcd")
b2 := mb.SplitFirst()
assert.String(b2.String()).Equals("efgh")
assert.Bool(mb.IsEmpty()).IsTrue()
assert(b2.String(), Equals, "efgh")
assert(mb.IsEmpty(), IsTrue)
_, err = reader.Read()
assert.Error(err).Equals(io.EOF)
_, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}

View File

@ -7,7 +7,7 @@ import (
"v2ray.com/core/common"
. "v2ray.com/core/common/crypto"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func mustDecodeHex(s string) []byte {
@ -17,7 +17,7 @@ func mustDecodeHex(s string) []byte {
}
func TestChaCha20Stream(t *testing.T) {
assert := assert.On(t)
assert := With(t)
var cases = []struct {
key []byte
@ -51,12 +51,12 @@ func TestChaCha20Stream(t *testing.T) {
input := make([]byte, len(c.output))
actualOutout := make([]byte, len(c.output))
s.XORKeyStream(actualOutout, input)
assert.Bytes(c.output).Equals(actualOutout)
assert(c.output, Equals, actualOutout)
}
}
func TestChaCha20Decoding(t *testing.T) {
assert := assert.On(t)
assert := With(t)
key := make([]byte, 32)
rand.Read(key)
@ -72,5 +72,5 @@ func TestChaCha20Decoding(t *testing.T) {
stream2 := NewChaCha20Stream(key, iv)
stream2.XORKeyStream(x, x)
assert.Bytes(x).Equals(payload)
assert(x, Equals, payload)
}

View File

@ -3,15 +3,18 @@ package crypto
import (
"io"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"v2ray.com/core/common/serial"
)
// ChunkSizeDecoder is an utility class to decode size value from bytes.
type ChunkSizeDecoder interface {
SizeBytes() int
Decode([]byte) (uint16, error)
}
// ChunkSizeEncoder is an utility class to encode size value into bytes.
type ChunkSizeEncoder interface {
SizeBytes() int
Encode(uint16, []byte) []byte
@ -31,50 +34,53 @@ func (PlainChunkSizeParser) Decode(b []byte) (uint16, error) {
return serial.BytesToUint16(b), nil
}
type AEADChunkSizeParser struct {
Auth *AEADAuthenticator
}
func (p *AEADChunkSizeParser) SizeBytes() int {
return 2 + p.Auth.Overhead()
}
func (p *AEADChunkSizeParser) Encode(size uint16, b []byte) []byte {
b = serial.Uint16ToBytes(size-uint16(p.Auth.Overhead()), b)
b, err := p.Auth.Seal(b[:0], b)
common.Must(err)
return b
}
func (p *AEADChunkSizeParser) Decode(b []byte) (uint16, error) {
b, err := p.Auth.Open(b[:0], b)
if err != nil {
return 0, err
}
return serial.BytesToUint16(b) + uint16(p.Auth.Overhead()), nil
}
type ChunkStreamReader struct {
sizeDecoder ChunkSizeDecoder
reader buf.Reader
reader *buf.BufferedReader
buffer []byte
leftOver buf.MultiBuffer
leftOverSize int
}
func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader {
return &ChunkStreamReader{
sizeDecoder: sizeDecoder,
reader: buf.NewReader(reader),
reader: buf.NewBufferedReader(buf.NewReader(reader)),
buffer: make([]byte, sizeDecoder.SizeBytes()),
}
}
func (r *ChunkStreamReader) readAtLeast(size int) error {
mb := r.leftOver
r.leftOver = nil
for mb.Len() < size {
extra, err := r.reader.Read()
if err != nil {
mb.Release()
return err
}
mb.AppendMulti(extra)
}
r.leftOver = mb
return nil
}
func (r *ChunkStreamReader) readSize() (uint16, error) {
if r.sizeDecoder.SizeBytes() > r.leftOver.Len() {
if err := r.readAtLeast(r.sizeDecoder.SizeBytes() - r.leftOver.Len()); err != nil {
return 0, err
}
if _, err := io.ReadFull(r.reader, r.buffer); err != nil {
return 0, err
}
r.leftOver.Read(r.buffer)
return r.sizeDecoder.Decode(r.buffer)
}
func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) {
func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
size := r.leftOverSize
if size == 0 {
nextSize, err := r.readSize()
@ -86,29 +92,14 @@ func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) {
}
size = int(nextSize)
}
r.leftOverSize = size
if r.leftOver.IsEmpty() {
if err := r.readAtLeast(1); err != nil {
return nil, err
}
}
if size >= r.leftOver.Len() {
mb := r.leftOver
r.leftOverSize = size - r.leftOver.Len()
r.leftOver = nil
mb, err := r.reader.ReadAtMost(size)
if !mb.IsEmpty() {
r.leftOverSize -= mb.Len()
return mb, nil
}
mb := r.leftOver.SliceBySize(size)
if mb.Len() != size {
b := buf.New()
b.AppendSupplier(buf.ReadFullFrom(&r.leftOver, size-mb.Len()))
mb.Append(b)
}
r.leftOverSize = 0
return mb, nil
return nil, err
}
type ChunkStreamWriter struct {
@ -123,18 +114,19 @@ func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *Chunk
}
}
func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
mb2Write := buf.NewMultiBuffer()
func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
const sliceSize = 8192
mbLen := mb.Len()
mb2Write := buf.NewMultiBufferCap(mbLen/buf.Size + mbLen/sliceSize + 2)
for {
slice := mb.SliceBySize(sliceSize)
b := buf.New()
b.AppendSupplier(func(buffer []byte) (int, error) {
common.Must(b.Reset(func(buffer []byte) (int, error) {
w.sizeEncoder.Encode(uint16(slice.Len()), buffer[:0])
return w.sizeEncoder.SizeBytes(), nil
})
}))
mb2Write.Append(b)
mb2Write.AppendMulti(slice)
@ -143,5 +135,5 @@ func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
}
}
return w.writer.Write(mb2Write)
return w.writer.WriteMultiBuffer(mb2Write)
}

View File

@ -6,11 +6,11 @@ import (
"v2ray.com/core/common/buf"
. "v2ray.com/core/common/crypto"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestChunkStreamIO(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cache := buf.NewLocal(8192)
@ -19,26 +19,26 @@ func TestChunkStreamIO(t *testing.T) {
b := buf.New()
b.AppendBytes('a', 'b', 'c', 'd')
assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil()
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
b = buf.New()
b.AppendBytes('e', 'f', 'g')
assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil()
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
assert.Int(cache.Len()).Equals(13)
assert(cache.Len(), Equals, 13)
mb, err := reader.Read()
assert.Error(err).IsNil()
assert.Int(mb.Len()).Equals(4)
assert.Bytes(mb[0].Bytes()).Equals([]byte("abcd"))
mb, err := reader.ReadMultiBuffer()
assert(err, IsNil)
assert(mb.Len(), Equals, 4)
assert(mb[0].Bytes(), Equals, []byte("abcd"))
mb, err = reader.Read()
assert.Error(err).IsNil()
assert.Int(mb.Len()).Equals(3)
assert.Bytes(mb[0].Bytes()).Equals([]byte("efg"))
mb, err = reader.ReadMultiBuffer()
assert(err, IsNil)
assert(mb.Len(), Equals, 3)
assert(mb[0].Bytes(), Equals, []byte("efg"))
_, err = reader.Read()
assert.Error(err).Equals(io.EOF)
_, err = reader.ReadMultiBuffer()
assert(err, Equals, io.EOF)
}

View File

@ -28,7 +28,7 @@ func (r *CryptionReader) Read(data []byte) (int, error) {
}
var (
_ buf.MultiBufferWriter = (*CryptionWriter)(nil)
_ buf.Writer = (*CryptionWriter)(nil)
)
type CryptionWriter struct {
@ -51,6 +51,8 @@ func (w *CryptionWriter) Write(data []byte) (int, error) {
}
func (w *CryptionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
defer mb.Release()
bs := mb.ToNetBuffers()
for _, b := range bs {
w.stream.XORKeyStream(b, b)

View File

@ -5,29 +5,29 @@ import (
"testing"
. "v2ray.com/core/common/errors"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestError(t *testing.T) {
assert := assert.On(t)
assert := With(t)
err := New("TestError")
assert.Bool(GetSeverity(err) == SeverityInfo).IsTrue()
assert(GetSeverity(err), Equals, SeverityInfo)
err = New("TestError2").Base(io.EOF)
assert.Bool(GetSeverity(err) == SeverityInfo).IsTrue()
assert(GetSeverity(err), Equals, SeverityInfo)
err = New("TestError3").Base(io.EOF).AtWarning()
assert.Bool(GetSeverity(err) == SeverityWarning).IsTrue()
assert(GetSeverity(err), Equals, SeverityWarning)
err = New("TestError4").Base(io.EOF).AtWarning()
err = New("TestError5").Base(err)
assert.Bool(GetSeverity(err) == SeverityWarning).IsTrue()
assert.String(err.Error()).Contains("EOF")
assert(GetSeverity(err), Equals, SeverityWarning)
assert(err.Error(), HasSubstring, "EOF")
}
func TestErrorMessage(t *testing.T) {
assert := assert.On(t)
assert := With(t)
data := []struct {
err error
@ -44,6 +44,6 @@ func TestErrorMessage(t *testing.T) {
}
for _, d := range data {
assert.String(d.err.Error()).Equals(d.msg)
assert(d.err.Error(), Equals, d.msg)
}
}

46
common/event/event.go Normal file
View File

@ -0,0 +1,46 @@
package event
import "sync"
type Event uint16
type Handler func(data interface{}) error
type Registry interface {
On(Event, Handler)
}
type Listener struct {
sync.RWMutex
events map[Event][]Handler
}
func (l *Listener) On(e Event, h Handler) {
l.Lock()
defer l.Unlock()
if l.events == nil {
l.events = make(map[Event][]Handler)
}
handlers := l.events[e]
handlers = append(handlers, h)
l.events[e] = handlers
}
func (l *Listener) Fire(e Event, data interface{}) error {
l.RLock()
defer l.RUnlock()
if l.events == nil {
return nil
}
for _, h := range l.events[e] {
if err := h(data); err != nil {
return err
}
}
return nil
}

View File

@ -73,6 +73,12 @@ type Address interface {
// ParseAddress parses a string into an Address. The return value will be an IPAddress when
// the string is in the form of IPv4 or IPv6 address, or a DomainAddress otherwise.
func ParseAddress(addr string) Address {
// Handle IPv6 address in form as "[2001:4860:0:2001::68]"
lenAddr := len(addr)
if lenAddr > 0 && addr[0] == '[' && addr[lenAddr-1] == ']' {
addr = addr[1 : lenAddr-1]
}
ip := net.ParseIP(addr)
if ip != nil {
return IPAddress(ip)

View File

@ -5,24 +5,25 @@ import (
"testing"
. "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/common/net/testing"
. "v2ray.com/ext/assert"
)
func TestIPv4Address(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ip := []byte{byte(1), byte(2), byte(3), byte(4)}
addr := IPAddress(ip)
assert.Address(addr).IsIPv4()
assert.Address(addr).IsNotIPv6()
assert.Address(addr).IsNotDomain()
assert.Bytes(addr.IP()).Equals(ip)
assert.Address(addr).EqualsString("1.2.3.4")
assert(addr, IsIPv4)
assert(addr, Not(IsIPv6))
assert(addr, Not(IsDomain))
assert([]byte(addr.IP()), Equals, ip)
assert(addr.String(), Equals, "1.2.3.4")
}
func TestIPv6Address(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ip := []byte{
byte(1), byte(2), byte(3), byte(4),
@ -32,15 +33,15 @@ func TestIPv6Address(t *testing.T) {
}
addr := IPAddress(ip)
assert.Address(addr).IsIPv6()
assert.Address(addr).IsNotIPv4()
assert.Address(addr).IsNotDomain()
assert.IP(addr.IP()).Equals(net.IP(ip))
assert.Address(addr).EqualsString("[102:304:102:304:102:304:102:304]")
assert(addr, IsIPv6)
assert(addr, Not(IsIPv4))
assert(addr, Not(IsDomain))
assert(addr.IP(), Equals, net.IP(ip))
assert(addr.String(), Equals, "[102:304:102:304:102:304:102:304]")
}
func TestIPv4Asv6(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ip := []byte{
byte(0), byte(0), byte(0), byte(0),
byte(0), byte(0), byte(0), byte(0),
@ -48,27 +49,55 @@ func TestIPv4Asv6(t *testing.T) {
byte(1), byte(2), byte(3), byte(4),
}
addr := IPAddress(ip)
assert.Address(addr).EqualsString("1.2.3.4")
assert(addr.String(), Equals, "1.2.3.4")
}
func TestDomainAddress(t *testing.T) {
assert := assert.On(t)
assert := With(t)
domain := "v2ray.com"
addr := DomainAddress(domain)
assert.Address(addr).IsDomain()
assert.Address(addr).IsNotIPv6()
assert.Address(addr).IsNotIPv4()
assert.String(addr.Domain()).Equals(domain)
assert.Address(addr).EqualsString("v2ray.com")
assert(addr, IsDomain)
assert(addr, Not(IsIPv6))
assert(addr, Not(IsIPv4))
assert(addr.Domain(), Equals, domain)
assert(addr.String(), Equals, "v2ray.com")
}
func TestNetIPv4Address(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ip := net.IPv4(1, 2, 3, 4)
addr := IPAddress(ip)
assert.Address(addr).IsIPv4()
assert.Address(addr).EqualsString("1.2.3.4")
assert(addr, IsIPv4)
assert(addr.String(), Equals, "1.2.3.4")
}
func TestParseIPv6Address(t *testing.T) {
assert := With(t)
ip := ParseAddress("[2001:4860:0:2001::68]")
assert(ip, IsIPv6)
assert(ip.String(), Equals, "[2001:4860:0:2001::68]")
ip = ParseAddress("[::ffff:123.151.71.143]")
assert(ip, IsIPv4)
assert(ip.String(), Equals, "123.151.71.143")
}
func TestInvalidAddressConvertion(t *testing.T) {
assert := With(t)
assert(func() { ParseAddress("8.8.8.8").Domain() }, Panics)
assert(func() { ParseAddress("2001:4860:0:2001::68").Domain() }, Panics)
assert(func() { ParseAddress("v2ray.com").IP() }, Panics)
}
func TestIPOrDomain(t *testing.T) {
assert := With(t)
assert(NewIPOrDomain(ParseAddress("v2ray.com")).AsAddress(), Equals, ParseAddress("v2ray.com"))
assert(NewIPOrDomain(ParseAddress("8.8.8.8")).AsAddress(), Equals, ParseAddress("8.8.8.8"))
assert(NewIPOrDomain(ParseAddress("2001:4860:0:2001::68")).AsAddress(), Equals, ParseAddress("2001:4860:0:2001::68"))
}

View File

@ -41,14 +41,17 @@ func UDPDestination(address Address, port Port) Destination {
}
}
// NetAddr returns the network address in this Destination in string form.
func (d Destination) NetAddr() string {
return d.Address.String() + ":" + d.Port.String()
}
// String returns the strings form of this Destination.
func (d Destination) String() string {
return d.Network.URLPrefix() + ":" + d.NetAddr()
}
// IsValid returns true if this Destination is valid.
func (d Destination) IsValid() bool {
return d.Network != Network_Unknown
}

View File

@ -4,23 +4,24 @@ import (
"testing"
. "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert"
. "v2ray.com/core/common/net/testing"
. "v2ray.com/ext/assert"
)
func TestTCPDestination(t *testing.T) {
assert := assert.On(t)
assert := With(t)
dest := TCPDestination(IPAddress([]byte{1, 2, 3, 4}), 80)
assert.Destination(dest).IsTCP()
assert.Destination(dest).IsNotUDP()
assert.Destination(dest).EqualsString("tcp:1.2.3.4:80")
assert(dest, IsTCP)
assert(dest, Not(IsUDP))
assert(dest.String(), Equals, "tcp:1.2.3.4:80")
}
func TestUDPDestination(t *testing.T) {
assert := assert.On(t)
assert := With(t)
dest := UDPDestination(IPAddress([]byte{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88}), 53)
assert.Destination(dest).IsNotTCP()
assert.Destination(dest).IsUDP()
assert.Destination(dest).EqualsString("udp:[2001:4860:4860::8888]:53")
assert(dest, Not(IsTCP))
assert(dest, IsUDP)
assert(dest.String(), Equals, "udp:[2001:4860:4860::8888]:53")
}

View File

@ -44,6 +44,7 @@ func (n *IPNetTable) Add(ipNet *net.IPNet) {
func (n *IPNetTable) AddIP(ip []byte, mask byte) {
k := ipToUint32(ip)
k = (k >> (32 - mask)) << (32 - mask) // normalize ip
existing, found := n.cache[k]
if !found || existing > mask {
n.cache[k] = mask

View File

@ -2,11 +2,19 @@ package net_test
import (
"net"
"os"
"path/filepath"
"testing"
proto "github.com/golang/protobuf/proto"
"v2ray.com/core/app/router"
"v2ray.com/core/common/platform"
"v2ray.com/ext/sysio"
"v2ray.com/core/common"
. "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func parseCIDR(str string) *net.IPNet {
@ -16,7 +24,7 @@ func parseCIDR(str string) *net.IPNet {
}
func TestIPNet(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ipNet := NewIPNetTable()
ipNet.Add(parseCIDR(("0.0.0.0/8")))
@ -32,12 +40,94 @@ func TestIPNet(t *testing.T) {
ipNet.Add(parseCIDR(("198.51.100.0/24")))
ipNet.Add(parseCIDR(("203.0.113.0/24")))
ipNet.Add(parseCIDR(("8.8.8.8/32")))
assert.Bool(ipNet.Contains(ParseIP("192.168.1.1"))).IsTrue()
assert.Bool(ipNet.Contains(ParseIP("192.0.0.0"))).IsTrue()
assert.Bool(ipNet.Contains(ParseIP("192.0.1.0"))).IsFalse()
assert.Bool(ipNet.Contains(ParseIP("0.1.0.0"))).IsTrue()
assert.Bool(ipNet.Contains(ParseIP("1.0.0.1"))).IsFalse()
assert.Bool(ipNet.Contains(ParseIP("8.8.8.7"))).IsFalse()
assert.Bool(ipNet.Contains(ParseIP("8.8.8.8"))).IsTrue()
assert.Bool(ipNet.Contains(ParseIP("2001:cdba::3257:9652"))).IsFalse()
ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
}
func TestGeoIPCN(t *testing.T) {
assert := With(t)
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := NewIPNetTable()
for _, ip := range ips {
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
}
assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
}
func loadGeoIP(country string) ([]*router.CIDR, error) {
geoipBytes, err := sysio.ReadAsset("geoip.dat")
if err != nil {
return nil, err
}
var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
}
for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
}
}
panic("country not found: " + country)
}
func BenchmarkIPNetQuery(b *testing.B) {
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := NewIPNetTable()
for _, ip := range ips {
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ipNet.Contains([]byte{8, 8, 8, 8})
}
}
func BenchmarkCIDRQuery(b *testing.B) {
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
ips, err := loadGeoIP("CN")
common.Must(err)
ipNet := make([]*net.IPNet, 0, 1024)
for _, ip := range ips {
if len(ip.Ip) != 4 {
continue
}
ipNet = append(ipNet, &net.IPNet{
IP: net.IP(ip.Ip),
Mask: net.CIDRMask(int(ip.Prefix), 32),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, n := range ipNet {
if n.Contains([]byte{8, 8, 8, 8}) {
break
}
}
}
}

View File

@ -1,4 +1,4 @@
// Package net contains common network utilities.
// Package net is a drop-in replacement to Golang's net package, with some more functionalities.
package net
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg net -path Net

View File

@ -4,15 +4,15 @@ import (
"testing"
. "v2ray.com/core/common/net"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestPortRangeContains(t *testing.T) {
assert := assert.On(t)
assert := With(t)
portRange := &PortRange{
From: 53,
To: 53,
}
assert.Bool(portRange.Contains(Port(53))).IsTrue()
assert(portRange.Contains(Port(53)), IsTrue)
}

View File

@ -2,6 +2,7 @@ package net
import "net"
// DialTCP is an injectable function. Default to net.DialTCP
var DialTCP = net.DialTCP
var DialUDP = net.DialUDP
var DialUnix = net.DialUnix
@ -31,6 +32,7 @@ type UDPConn = net.UDPConn
type UnixAddr = net.UnixAddr
type UnixConn = net.UnixConn
// IP is an alias for net.IP.
type IP = net.IP
type IPMask = net.IPMask
type IPNet = net.IPNet

View File

@ -0,0 +1,48 @@
package testing
import (
"v2ray.com/core/common/net"
"v2ray.com/ext/assert"
)
var IsIPv4 = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsIPv4()
}, "is IPv4")
var IsIPv6 = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsIPv6()
}, "is IPv6")
var IsIP = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsIPv4() || a.Family().IsIPv6()
}, "is IP")
var IsTCP = assert.CreateMatcher(func(a net.Destination) bool {
return a.Network == net.Network_TCP
}, "is TCP")
var IsUDP = assert.CreateMatcher(func(a net.Destination) bool {
return a.Network == net.Network_UDP
}, "is UDP")
var IsDomain = assert.CreateMatcher(func(a net.Address) bool {
return a.Family().IsDomain()
}, "is Domain")
func init() {
assert.RegisterEqualsMatcher(func(a, b net.Address) bool {
return a == b
})
assert.RegisterEqualsMatcher(func(a, b net.Destination) bool {
return a == b
})
assert.RegisterEqualsMatcher(func(a, b net.Port) bool {
return a == b
})
assert.RegisterEqualsMatcher(func(a, b net.IP) bool {
return a.Equal(b)
})
}

View File

@ -4,6 +4,7 @@ package platform
import (
"os"
"path/filepath"
)
func ExpandEnv(s string) string {
@ -13,3 +14,9 @@ func ExpandEnv(s string) string {
func LineSeparator() string {
return "\n"
}
func GetToolLocation(file string) string {
const name = "v2ray.location.tool"
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(toolPath, file)
}

View File

@ -2,6 +2,7 @@ package platform
import (
"os"
"path/filepath"
"strconv"
"strings"
)
@ -11,7 +12,7 @@ type EnvFlag struct {
AltName string
}
func (f EnvFlag) GetValue(defaultValue string) string {
func (f EnvFlag) GetValue(defaultValue func() string) string {
if v, found := os.LookupEnv(f.Name); found {
return v
}
@ -21,13 +22,16 @@ func (f EnvFlag) GetValue(defaultValue string) string {
}
}
return defaultValue
return defaultValue()
}
func (f EnvFlag) GetValueAsInt(defaultValue int) int {
const PlaceHolder = "xxxxxx"
s := f.GetValue(PlaceHolder)
if s == PlaceHolder {
useDefaultValue := false
s := f.GetValue(func() string {
useDefaultValue = true
return ""
})
if useDefaultValue {
return defaultValue
}
v, err := strconv.ParseInt(s, 10, 32)
@ -40,3 +44,29 @@ func (f EnvFlag) GetValueAsInt(defaultValue int) int {
func NormalizeEnvName(name string) string {
return strings.Replace(strings.ToUpper(strings.TrimSpace(name)), ".", "_", -1)
}
func getExecutableDir() string {
exec, err := os.Executable()
if err != nil {
return ""
}
return filepath.Dir(exec)
}
func getExecuableSubDir(dir string) func() string {
return func() string {
return filepath.Join(getExecutableDir(), dir)
}
}
func GetAssetLocation(file string) string {
const name = "v2ray.location.asset"
assetPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(assetPath, file)
}
func GetPluginDirectory() string {
const name = "v2ray.location.plugin"
pluginDir := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecuableSubDir("plugins"))
return pluginDir
}

View File

@ -1,14 +1,16 @@
package platform_test
import (
"os"
"path/filepath"
"testing"
. "v2ray.com/core/common/platform"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestNormalizeEnvName(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cases := []struct {
input string
@ -28,14 +30,27 @@ func TestNormalizeEnvName(t *testing.T) {
},
}
for _, test := range cases {
assert.String(NormalizeEnvName(test.input)).Equals(test.output)
assert(NormalizeEnvName(test.input), Equals, test.output)
}
}
func TestEnvFlag(t *testing.T) {
assert := assert.On(t)
assert := With(t)
assert.Int(EnvFlag{
assert(EnvFlag{
Name: "xxxxx.y",
}.GetValueAsInt(10)).Equals(10)
}.GetValueAsInt(10), Equals, 10)
}
func TestGetAssetLocation(t *testing.T) {
assert := With(t)
exec, err := os.Executable()
assert(err, IsNil)
loc := GetAssetLocation("t")
assert(filepath.Dir(loc), Equals, filepath.Dir(exec))
os.Setenv("v2ray.location.asset", "/v2ray")
assert(GetAssetLocation("t"), Equals, "/v2ray/t")
}

View File

@ -2,6 +2,8 @@
package platform
import "path/filepath"
func ExpandEnv(s string) string {
// TODO
return s
@ -10,3 +12,9 @@ func ExpandEnv(s string) string {
func LineSeparator() string {
return "\r\n"
}
func GetToolLocation(file string) string {
const name = "v2ray.location.tool"
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
return filepath.Join(toolPath, file+".exe")
}

View File

@ -3,6 +3,7 @@ package protocol
import (
"runtime"
"v2ray.com/core/common/bitmask"
"v2ray.com/core/common/net"
"v2ray.com/core/common/uuid"
)
@ -24,35 +25,20 @@ func (c RequestCommand) TransferType() TransferType {
return TransferTypePacket
}
// RequestOption is the options of a request.
type RequestOption byte
const (
// RequestOptionChunkStream indicates request payload is chunked. Each chunk consists of length, authentication and payload.
RequestOptionChunkStream = RequestOption(0x01)
RequestOptionChunkStream bitmask.Byte = 0x01
// RequestOptionConnectionReuse indicates client side expects to reuse the connection.
RequestOptionConnectionReuse = RequestOption(0x02)
RequestOptionConnectionReuse bitmask.Byte = 0x02
RequestOptionChunkMasking = RequestOption(0x04)
RequestOptionChunkMasking bitmask.Byte = 0x04
)
func (v RequestOption) Has(option RequestOption) bool {
return (v & option) == option
}
func (v *RequestOption) Set(option RequestOption) {
*v = (*v | option)
}
func (v *RequestOption) Clear(option RequestOption) {
*v = (*v & (^option))
}
type Security byte
func (v Security) Is(t SecurityType) bool {
return v == Security(t)
func (s Security) Is(t SecurityType) bool {
return s == Security(t)
}
func NormSecurity(s Security) Security {
@ -65,42 +51,28 @@ func NormSecurity(s Security) Security {
type RequestHeader struct {
Version byte
Command RequestCommand
Option RequestOption
Option bitmask.Byte
Security Security
Port net.Port
Address net.Address
User *User
}
func (v *RequestHeader) Destination() net.Destination {
if v.Command == RequestCommandUDP {
return net.UDPDestination(v.Address, v.Port)
func (h *RequestHeader) Destination() net.Destination {
if h.Command == RequestCommandUDP {
return net.UDPDestination(h.Address, h.Port)
}
return net.TCPDestination(v.Address, v.Port)
return net.TCPDestination(h.Address, h.Port)
}
type ResponseOption byte
const (
ResponseOptionConnectionReuse = ResponseOption(0x01)
ResponseOptionConnectionReuse bitmask.Byte = 0x01
)
func (v *ResponseOption) Set(option ResponseOption) {
*v = (*v | option)
}
func (v ResponseOption) Has(option ResponseOption) bool {
return (v & option) == option
}
func (v *ResponseOption) Clear(option ResponseOption) {
*v = (*v & (^option))
}
type ResponseCommand interface{}
type ResponseHeader struct {
Option ResponseOption
Option bitmask.Byte
Command ResponseCommand
}
@ -108,20 +80,21 @@ type CommandSwitchAccount struct {
Host net.Address
Port net.Port
ID *uuid.UUID
AlterIds uint16
Level uint32
AlterIds uint16
ValidMin byte
}
func (v *SecurityConfig) AsSecurity() Security {
if v == nil {
return Security(SecurityType_LEGACY)
}
if v.Type == SecurityType_AUTO {
func (sc *SecurityConfig) AsSecurity() Security {
if sc == nil || sc.Type == SecurityType_AUTO {
if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" {
return Security(SecurityType_AES128_GCM)
}
return Security(SecurityType_CHACHA20_POLY1305)
}
return NormSecurity(Security(v.Type))
return NormSecurity(Security(sc.Type))
}
func IsDomainTooLong(domain string) bool {
return len(domain) > 256
}

View File

@ -1,34 +0,0 @@
package protocol_test
import (
"testing"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
)
func TestRequestOptionSet(t *testing.T) {
assert := assert.On(t)
var option RequestOption
assert.Bool(option.Has(RequestOptionChunkStream)).IsFalse()
option.Set(RequestOptionChunkStream)
assert.Bool(option.Has(RequestOptionChunkStream)).IsTrue()
option.Set(RequestOptionChunkMasking)
assert.Bool(option.Has(RequestOptionChunkMasking)).IsTrue()
assert.Bool(option.Has(RequestOptionChunkStream)).IsTrue()
}
func TestRequestOptionClear(t *testing.T) {
assert := assert.On(t)
var option RequestOption
option.Set(RequestOptionChunkStream)
option.Set(RequestOptionChunkMasking)
option.Clear(RequestOptionChunkStream)
assert.Bool(option.Has(RequestOptionChunkStream)).IsFalse()
assert.Bool(option.Has(RequestOptionChunkMasking)).IsTrue()
}

View File

@ -6,12 +6,12 @@ import (
"v2ray.com/core/common/predicate"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/common/uuid"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestCmdKey(t *testing.T) {
assert := assert.On(t)
assert := With(t)
id := NewID(uuid.New())
assert.Bool(predicate.BytesAll(id.CmdKey(), 0)).IsFalse()
assert(predicate.BytesAll(id.CmdKey(), 0), IsFalse)
}

View File

@ -6,3 +6,11 @@ const (
TransferTypeStream TransferType = 0
TransferTypePacket TransferType = 1
)
type AddressType byte
const (
AddressTypeIPv4 AddressType = 1
AddressTypeDomain AddressType = 2
AddressTypeIPv6 AddressType = 3
)

View File

@ -6,30 +6,30 @@ import (
"v2ray.com/core/common/net"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestServerList(t *testing.T) {
assert := assert.On(t)
assert := With(t)
list := NewServerList()
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
assert.Uint32(list.Size()).Equals(1)
assert(list.Size(), Equals, uint32(1))
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second))))
assert.Uint32(list.Size()).Equals(2)
assert(list.Size(), Equals, uint32(2))
server := list.GetServer(1)
assert.Port(server.Destination().Port).Equals(2)
assert(server.Destination().Port, Equals, net.Port(2))
time.Sleep(2 * time.Second)
server = list.GetServer(1)
assert.Pointer(server).IsNil()
assert(server, IsNil)
server = list.GetServer(0)
assert.Port(server.Destination().Port).Equals(1)
assert(server.Destination().Port, Equals, net.Port(1))
}
func TestServerPicker(t *testing.T) {
assert := assert.On(t)
assert := With(t)
list := NewServerList()
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
@ -38,17 +38,17 @@ func TestServerPicker(t *testing.T) {
picker := NewRoundRobinServerPicker(list)
server := picker.PickServer()
assert.Port(server.Destination().Port).Equals(1)
assert(server.Destination().Port, Equals, net.Port(1))
server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(2)
assert(server.Destination().Port, Equals, net.Port(2))
server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(3)
assert(server.Destination().Port, Equals, net.Port(3))
server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(1)
assert(server.Destination().Port, Equals, net.Port(1))
time.Sleep(2 * time.Second)
server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(1)
assert(server.Destination().Port, Equals, net.Port(1))
server = picker.PickServer()
assert.Port(server.Destination().Port).Equals(1)
assert(server.Destination().Port, Equals, net.Port(1))
}

View File

@ -5,27 +5,27 @@ import (
"time"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestAlwaysValidStrategy(t *testing.T) {
assert := assert.On(t)
assert := With(t)
strategy := AlwaysValid()
assert.Bool(strategy.IsValid()).IsTrue()
assert(strategy.IsValid(), IsTrue)
strategy.Invalidate()
assert.Bool(strategy.IsValid()).IsTrue()
assert(strategy.IsValid(), IsTrue)
}
func TestTimeoutValidStrategy(t *testing.T) {
assert := assert.On(t)
assert := With(t)
strategy := BeforeTime(time.Now().Add(2 * time.Second))
assert.Bool(strategy.IsValid()).IsTrue()
assert(strategy.IsValid(), IsTrue)
time.Sleep(3 * time.Second)
assert.Bool(strategy.IsValid()).IsFalse()
assert(strategy.IsValid(), IsFalse)
strategy = BeforeTime(time.Now().Add(2 * time.Second))
strategy.Invalidate()
assert.Bool(strategy.IsValid()).IsFalse()
assert(strategy.IsValid(), IsFalse)
}

View File

@ -5,11 +5,11 @@ import (
"time"
. "v2ray.com/core/common/protocol"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestGenerateRandomInt64InRange(t *testing.T) {
assert := assert.On(t)
assert := With(t)
base := time.Now().Unix()
delta := 100
@ -17,7 +17,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
for i := 0; i < 100; i++ {
val := int64(generator())
assert.Int64(val).AtMost(base + int64(delta))
assert.Int64(val).AtLeast(base - int64(delta))
assert(val, AtMost, base+int64(delta))
assert(val, AtLeast, base-int64(delta))
}
}

View File

@ -1,7 +1,5 @@
package protocol
import "time"
func (u *User) GetTypedAccount() (Account, error) {
if u.GetAccount() == nil {
return nil, newError("Account missing").AtWarning()
@ -19,20 +17,3 @@ func (u *User) GetTypedAccount() (Account, error) {
}
return nil, newError("Unknown account type: ", u.Account.Type)
}
func (u *User) GetSettings() UserSettings {
settings := UserSettings{}
switch u.Level {
case 0:
settings.PayloadTimeout = time.Second * 30
case 1:
settings.PayloadTimeout = time.Minute * 2
default:
settings.PayloadTimeout = time.Minute * 5
}
return settings
}
type UserSettings struct {
PayloadTimeout time.Duration
}

View File

@ -6,7 +6,7 @@ import (
"v2ray.com/core/common/errors"
. "v2ray.com/core/common/retry"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
var (
@ -14,7 +14,7 @@ var (
)
func TestNoRetry(t *testing.T) {
assert := assert.On(t)
assert := With(t)
startTime := time.Now().Unix()
err := Timed(10, 100000).On(func() error {
@ -22,12 +22,12 @@ func TestNoRetry(t *testing.T) {
})
endTime := time.Now().Unix()
assert.Error(err).IsNil()
assert.Int64(endTime - startTime).AtLeast(0)
assert(err, IsNil)
assert(endTime-startTime, AtLeast, int64(0))
}
func TestRetryOnce(t *testing.T) {
assert := assert.On(t)
assert := With(t)
startTime := time.Now()
called := 0
@ -40,12 +40,12 @@ func TestRetryOnce(t *testing.T) {
})
duration := time.Since(startTime)
assert.Error(err).IsNil()
assert.Int64(int64(duration / time.Millisecond)).AtLeast(900)
assert(err, IsNil)
assert(int64(duration/time.Millisecond), AtLeast, int64(900))
}
func TestRetryMultiple(t *testing.T) {
assert := assert.On(t)
assert := With(t)
startTime := time.Now()
called := 0
@ -58,12 +58,12 @@ func TestRetryMultiple(t *testing.T) {
})
duration := time.Since(startTime)
assert.Error(err).IsNil()
assert.Int64(int64(duration / time.Millisecond)).AtLeast(4900)
assert(err, IsNil)
assert(int64(duration/time.Millisecond), AtLeast, int64(4900))
}
func TestRetryExhausted(t *testing.T) {
assert := assert.On(t)
assert := With(t)
startTime := time.Now()
called := 0
@ -73,12 +73,12 @@ func TestRetryExhausted(t *testing.T) {
})
duration := time.Since(startTime)
assert.Error(errors.Cause(err)).Equals(ErrRetryFailed)
assert.Int64(int64(duration / time.Millisecond)).AtLeast(1900)
assert(errors.Cause(err), Equals, ErrRetryFailed)
assert(int64(duration/time.Millisecond), AtLeast, int64(1900))
}
func TestExponentialBackoff(t *testing.T) {
assert := assert.On(t)
assert := With(t)
startTime := time.Now()
called := 0
@ -88,6 +88,6 @@ func TestExponentialBackoff(t *testing.T) {
})
duration := time.Since(startTime)
assert.Error(errors.Cause(err)).Equals(ErrRetryFailed)
assert.Int64(int64(duration / time.Millisecond)).AtLeast(4000)
assert(errors.Cause(err), Equals, ErrRetryFailed)
assert(int64(duration/time.Millisecond), AtLeast, int64(4000))
}

View File

@ -4,11 +4,11 @@ import (
"testing"
. "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestBytesToHex(t *testing.T) {
assert := assert.On(t)
assert := With(t)
cases := []struct {
input []byte
@ -21,15 +21,15 @@ func TestBytesToHex(t *testing.T) {
}
for _, test := range cases {
assert.String(test.output).Equals(BytesToHexString(test.input))
assert(test.output, Equals, BytesToHexString(test.input))
}
}
func TestInt64(t *testing.T) {
assert := assert.On(t)
assert := With(t)
x := int64(375134875348)
b := Int64ToBytes(x, []byte{})
v := BytesToInt64(b)
assert.Int64(x).Equals(v)
assert(x, Equals, v)
}

View File

@ -6,15 +6,15 @@ import (
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
. "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestUint32(t *testing.T) {
assert := assert.On(t)
assert := With(t)
x := uint32(458634234)
s1 := Uint32ToBytes(x, []byte{})
s2 := buf.New()
common.Must(s2.AppendSupplier(WriteUint32(x)))
assert.Bytes(s1).Equals(s2.Bytes())
assert(s1, Equals, s2.Bytes())
}

View File

@ -4,13 +4,13 @@ import (
"testing"
. "v2ray.com/core/common/serial"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestGetInstance(t *testing.T) {
assert := assert.On(t)
assert := With(t)
p, err := GetInstance("")
assert.Pointer(p).IsNil()
assert.Error(err).IsNotNil()
assert(p, IsNil)
assert(err, IsNotNil)
}

View File

@ -6,11 +6,11 @@ import (
"testing"
. "v2ray.com/core/common/signal"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestErrorOrFinish2_Error(t *testing.T) {
assert := assert.On(t)
assert := With(t)
c1 := make(chan error, 1)
c2 := make(chan error, 2)
@ -22,11 +22,11 @@ func TestErrorOrFinish2_Error(t *testing.T) {
c1 <- errors.New("test")
err := <-c
assert.String(err.Error()).Equals("test")
assert(err.Error(), Equals, "test")
}
func TestErrorOrFinish2_Error2(t *testing.T) {
assert := assert.On(t)
assert := With(t)
c1 := make(chan error, 1)
c2 := make(chan error, 2)
@ -38,11 +38,11 @@ func TestErrorOrFinish2_Error2(t *testing.T) {
c2 <- errors.New("test")
err := <-c
assert.String(err.Error()).Equals("test")
assert(err.Error(), Equals, "test")
}
func TestErrorOrFinish2_NoneError(t *testing.T) {
assert := assert.On(t)
assert := With(t)
c1 := make(chan error, 1)
c2 := make(chan error, 2)
@ -61,11 +61,11 @@ func TestErrorOrFinish2_NoneError(t *testing.T) {
close(c2)
err := <-c
assert.Error(err).IsNil()
assert(err, IsNil)
}
func TestErrorOrFinish2_NoneError2(t *testing.T) {
assert := assert.On(t)
assert := With(t)
c1 := make(chan error, 1)
c2 := make(chan error, 2)
@ -84,5 +84,5 @@ func TestErrorOrFinish2_NoneError2(t *testing.T) {
close(c1)
err := <-c
assert.Error(err).IsNil()
assert(err, IsNil)
}

View File

@ -12,8 +12,6 @@ type ActivityUpdater interface {
type ActivityTimer struct {
updated chan bool
timeout chan time.Duration
ctx context.Context
cancel context.CancelFunc
}
func (t *ActivityTimer) Update() {
@ -27,7 +25,7 @@ func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
t.timeout <- timeout
}
func (t *ActivityTimer) run() {
func (t *ActivityTimer) run(ctx context.Context, cancel context.CancelFunc) {
ticker := time.NewTicker(<-t.timeout)
defer func() {
ticker.Stop()
@ -36,32 +34,35 @@ func (t *ActivityTimer) run() {
for {
select {
case <-ticker.C:
case <-t.ctx.Done():
case <-ctx.Done():
return
case timeout := <-t.timeout:
if timeout == 0 {
cancel()
return
}
ticker.Stop()
ticker = time.NewTicker(timeout)
continue
}
select {
case <-t.updated:
// Updated keep waiting.
default:
t.cancel()
cancel()
return
}
}
}
func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
ctx, cancel := context.WithCancel(ctx)
func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
timer := &ActivityTimer{
ctx: ctx,
cancel: cancel,
timeout: make(chan time.Duration, 1),
updated: make(chan bool, 1),
}
timer.timeout <- timeout
go timer.run()
return ctx, timer
go timer.run(ctx, cancel)
return timer
}

View File

@ -7,26 +7,28 @@ import (
"time"
. "v2ray.com/core/common/signal"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestActivityTimer(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
ctx, cancel := context.WithCancel(context.Background())
timer := CancelAfterInactivity(ctx, cancel, time.Second*5)
time.Sleep(time.Second * 6)
assert.Error(ctx.Err()).IsNotNil()
assert(ctx.Err(), IsNotNil)
runtime.KeepAlive(timer)
}
func TestActivityTimerUpdate(t *testing.T) {
assert := assert.On(t)
assert := With(t)
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
ctx, cancel := context.WithCancel(context.Background())
timer := CancelAfterInactivity(ctx, cancel, time.Second*10)
time.Sleep(time.Second * 3)
assert.Error(ctx.Err()).IsNil()
assert(ctx.Err(), IsNil)
timer.SetTimeout(time.Second * 1)
time.Sleep(time.Second * 2)
assert.Error(ctx.Err()).IsNotNil()
assert(ctx.Err(), IsNotNil)
runtime.KeepAlive(timer)
}

View File

@ -4,74 +4,74 @@ import (
"testing"
. "v2ray.com/core/common/uuid"
"v2ray.com/core/testing/assert"
. "v2ray.com/ext/assert"
)
func TestParseBytes(t *testing.T) {
assert := assert.On(t)
assert := With(t)
str := "2418d087-648d-4990-86e8-19dca1d006d3"
bytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
uuid, err := ParseBytes(bytes)
assert.Error(err).IsNil()
assert.String(uuid.String()).Equals(str)
assert(err, IsNil)
assert(uuid.String(), Equals, str)
_, err = ParseBytes([]byte{1, 3, 2, 4})
assert.Error(err).IsNotNil()
assert(err, IsNotNil)
}
func TestParseString(t *testing.T) {
assert := assert.On(t)
assert := With(t)
str := "2418d087-648d-4990-86e8-19dca1d006d3"
expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
uuid, err := ParseString(str)
assert.Error(err).IsNil()
assert.Bytes(uuid.Bytes()).Equals(expectedBytes)
assert(err, IsNil)
assert(uuid.Bytes(), Equals, expectedBytes)
uuid, err = ParseString("2418d087")
assert.Error(err).IsNotNil()
assert(err, IsNotNil)
uuid, err = ParseString("2418d087-648k-4990-86e8-19dca1d006d3")
assert.Error(err).IsNotNil()
assert(err, IsNotNil)
}
func TestNewUUID(t *testing.T) {
assert := assert.On(t)
assert := With(t)
uuid := New()
uuid2, err := ParseString(uuid.String())
assert.Error(err).IsNil()
assert.String(uuid.String()).Equals(uuid2.String())
assert.Bytes(uuid.Bytes()).Equals(uuid2.Bytes())
assert(err, IsNil)
assert(uuid.String(), Equals, uuid2.String())
assert(uuid.Bytes(), Equals, uuid2.Bytes())
}
func TestRandom(t *testing.T) {
assert := assert.On(t)
assert := With(t)
uuid := New()
uuid2 := New()
assert.String(uuid.String()).NotEquals(uuid2.String())
assert.Bytes(uuid.Bytes()).NotEquals(uuid2.Bytes())
assert(uuid.String(), NotEquals, uuid2.String())
assert(uuid.Bytes(), NotEquals, uuid2.Bytes())
}
func TestEquals(t *testing.T) {
assert := assert.On(t)
assert := With(t)
var uuid *UUID = nil
var uuid2 *UUID = nil
assert.Bool(uuid.Equals(uuid2)).IsTrue()
assert.Bool(uuid.Equals(New())).IsFalse()
assert(uuid.Equals(uuid2), IsTrue)
assert(uuid.Equals(New()), IsFalse)
}
func TestNext(t *testing.T) {
assert := assert.On(t)
assert := With(t)
uuid := New()
uuid2 := uuid.Next()
assert.Bool(uuid.Equals(uuid2)).IsFalse()
assert(uuid.Equals(uuid2), IsFalse)
}

View File

@ -18,9 +18,9 @@ import (
)
var (
version = "2.41"
version = "3.1"
build = "Custom"
codename = "One for all"
codename = "die Commanderin"
intro = "An unified platform for anti-censorship."
)

View File

@ -2,10 +2,11 @@ package core
import (
"io"
"io/ioutil"
"v2ray.com/core/common"
"v2ray.com/core/common/buf"
"github.com/golang/protobuf/proto"
"v2ray.com/core/common"
)
// ConfigLoader is an utility to load V2Ray config from external source.
@ -30,7 +31,10 @@ func LoadConfig(format ConfigFormat, input io.Reader) (*Config, error) {
func loadProtobufConfig(input io.Reader) (*Config, error) {
config := new(Config)
data, _ := ioutil.ReadAll(input)
data, err := buf.ReadAllToBytes(input)
if err != nil {
return nil, err
}
if err := proto.Unmarshal(data, config); err != nil {
return nil, err
}

View File

@ -1,5 +1,47 @@
// +build json
package main
import _ "v2ray.com/core/tools/conf"
import (
"io"
"os"
"os/exec"
"v2ray.com/core"
"v2ray.com/core/common/platform"
)
func jsonToProto(input io.Reader) (*core.Config, error) {
v2ctl := platform.GetToolLocation("v2ctl")
_, err := os.Stat(v2ctl)
if err != nil {
return nil, err
}
cmd := exec.Command(v2ctl, "config")
cmd.Stdin = input
cmd.Stderr = os.Stderr
stdoutReader, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}
defer stdoutReader.Close()
if err := cmd.Start(); err != nil {
return nil, err
}
config, err := core.LoadConfig(core.ConfigFormat_Protobuf, stdoutReader)
cmd.Wait()
return config, err
}
func init() {
core.RegisterConfigLoader(core.ConfigFormat_JSON, func(input io.Reader) (*core.Config, error) {
config, err := jsonToProto(input)
if err != nil {
return nil, newError("failed to execute v2ctl to convert config file.").Base(err)
}
return config, nil
})
}

View File

@ -4,6 +4,7 @@ import (
// The following are necessary as they register handlers in their init functions.
_ "v2ray.com/core/app/dispatcher/impl"
_ "v2ray.com/core/app/dns/server"
_ "v2ray.com/core/app/policy/manager"
_ "v2ray.com/core/app/proxyman/inbound"
_ "v2ray.com/core/app/proxyman/outbound"
_ "v2ray.com/core/app/router"

View File

@ -22,6 +22,7 @@ var (
version = flag.Bool("version", false, "Show current version of V2Ray.")
test = flag.Bool("test", false, "Test config file only, without launching V2Ray server.")
format = flag.String("format", "json", "Format of input file.")
plugin = flag.Bool("plugin", false, "True to load plugins.")
)
func init() {
@ -67,7 +68,7 @@ func startV2Ray() (core.Server, error) {
server, err := core.New(config)
if err != nil {
return nil, newError("failed to create initialize").Base(err)
return nil, newError("failed to create server").Base(err)
}
return server, nil
@ -82,19 +83,27 @@ func main() {
return
}
if *plugin {
if err := core.LoadPlugins(); err != nil {
fmt.Println("Failed to load plugins:", err.Error())
os.Exit(-1)
}
}
server, err := startV2Ray()
if err != nil {
fmt.Println(err.Error())
return
os.Exit(-1)
}
if *test {
fmt.Println("Configuration OK.")
return
os.Exit(0)
}
if err := server.Start(); err != nil {
fmt.Println("Failed to start", err)
os.Exit(-1)
}
osSignals := make(chan os.Signal, 1)

Some files were not shown because too many files have changed in this diff Show More