mirror of
https://github.com/v2fly/v2ray-core.git
synced 2024-12-22 10:08:15 -05:00
commit
d1fa98b60b
46
.github/CODE_OF_CONDUCT.md
vendored
Normal file
46
.github/CODE_OF_CONDUCT.md
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at love@v2ray.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]
|
||||
|
||||
[homepage]: http://contributor-covenant.org
|
||||
[version]: http://contributor-covenant.org/version/1/4/
|
92
.github/ISSUE_TEMPLATE
vendored
92
.github/ISSUE_TEMPLATE
vendored
@ -1,44 +1,94 @@
|
||||
提交 Issue 之前请先阅读 [Issue 指引](https://www.v2ray.com/zh_cn/chapter_01/issue.html),然后回答下面的问题,谢谢。
|
||||
Please read the [instruction](https://www.v2ray.com/en/get_started/issue.html) and answer the following questions before submitting your issue. Thank you.
|
||||
Please skip to the English section below if you don't write Chinese.
|
||||
|
||||
中文:
|
||||
提交 Issue 之前请先阅读 [Issue 指引](https://www.v2ray.com/chapter_01/issue.html),然后回答下面的问题,谢谢。
|
||||
除非特殊情况,请完整填写所有问题。不按模板发的 issue 将直接被关闭。
|
||||
|
||||
1) 你正在使用哪个版本的 V2Ray?(如果服务器和客户端使用了不同版本,请注明)
|
||||
What version of V2Ray are you using (If you deploy different version on server and client, please explicitly point out)?
|
||||
|
||||
2) 你的使用场景是什么?比如使用 Chrome 通过 Socks/VMess 代理观看 YouTube 视频。
|
||||
What's your scenario of using V2Ray? E.g., Watching YouTube videos in Chrome via Socks/VMess proxy.
|
||||
|
||||
3) 你看到的不正常的现象是什么?
|
||||
What did you see?
|
||||
3) 你看到的不正常的现象是什么?(请描述具体现象,比如访问超时,TLS 证书错误等)
|
||||
|
||||
4) 你期待看到的正确表现是怎样的?
|
||||
What's your expectation?
|
||||
|
||||
5) 请附上你的配置文件(提交 Issue 前请隐藏服务器端IP地址)。
|
||||
Please attach your configuration file (**Mask IP addresses before submit this issue**).
|
||||
|
||||
Server Configuration File(服务器端配置文件):
|
||||
5) 请附上你的配置(提交 Issue 前请隐藏服务器端IP地址)。
|
||||
|
||||
服务器端配置:
|
||||
```javascript
|
||||
// 在这里附上服务器端配置文件
|
||||
// Please attach your server configuration file here.
|
||||
```
|
||||
|
||||
Client Configuration File(客户端配置文件):
|
||||
客户端配置:
|
||||
```javascript
|
||||
// 在这里附上客户端配置文件
|
||||
// Please attach your client configuration file here.
|
||||
// 在这里附上客户端配置
|
||||
```
|
||||
|
||||
6) 请附上出错时软件输出的日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
|
||||
Please attach the log file, especially the bottom lines if the file is large. Log file is usually `/var/log/v2ray/error.log` on Linux.
|
||||
6) 请附上出错时软件输出的错误日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
|
||||
|
||||
Server Log File(服务器端日志):
|
||||
服务器端错误日志:
|
||||
```
|
||||
// 在这里附上服务器端日志
|
||||
// Please attach your server log here.
|
||||
```
|
||||
|
||||
Client Log File(客户端日志):
|
||||
客户端错误日志:
|
||||
```
|
||||
// 在这里附上客户端日志
|
||||
// Please attach your client log here.
|
||||
```
|
||||
|
||||
7) 请附上访问日志。在 Linux 中,日志通常在 `/var/log/v2ray/error.log` 文件中。
|
||||
```
|
||||
// 在这里附上服务器端日志
|
||||
```
|
||||
|
||||
8) 其它相关的配置文件(如 Nginx)和相关日志。
|
||||
|
||||
请预览一下你填的内容再提交。
|
||||
|
||||
如果你已经填完上面的问卷,请把下面的英文部份删除,再提交 Issue。
|
||||
Please remove the Chinese section above.
|
||||
|
||||
English:
|
||||
Please read the [instruction](https://www.v2ray.com/en/get_started/issue.html) and answer the following questions before submitting your issue. Thank you.
|
||||
Please answer all the questions with enough information. All issues not following this template will be closed immediately.
|
||||
|
||||
1) What version of V2Ray are you using (If you deploy different version on server and client, please explicitly point out)?
|
||||
|
||||
2) What's your scenario of using V2Ray? E.g., Watching YouTube videos in Chrome via Socks/VMess proxy.
|
||||
|
||||
3) What did you see? (Please describe in detail, such as timeout, fake TLS certificate etc)
|
||||
|
||||
4) What's your expectation?
|
||||
|
||||
5) Please attach your configuration file (**Mask IP addresses before submit this issue**).
|
||||
|
||||
Server configuration:
|
||||
```javascript
|
||||
// Please attach your server configuration here.
|
||||
```
|
||||
|
||||
Client configuration:
|
||||
```javascript
|
||||
// Please attach your client configuration here.
|
||||
```
|
||||
|
||||
6) Please attach error logs, especially the bottom lines if the file is large. Error log file is usually at `/var/log/v2ray/error.log` on Linux.
|
||||
|
||||
Server error log:
|
||||
```
|
||||
// Please attach your server error log here.
|
||||
```
|
||||
Client error log:
|
||||
```
|
||||
// Please attach your client error log here.
|
||||
```
|
||||
|
||||
7) Please attach access log. Access log is usually at '/var/log/v2ray/access.log' on Linux.
|
||||
```
|
||||
// Please attach your server access log here.
|
||||
```
|
||||
|
||||
Please review your issue before submitting.
|
||||
|
||||
8) Other configurations (such as Nginx) and logs.
|
||||
|
||||
|
63
.github/SUPPORT.md
vendored
Normal file
63
.github/SUPPORT.md
vendored
Normal file
@ -0,0 +1,63 @@
|
||||
# V2Ray 用户支持 (User Support)
|
||||
|
||||
**English reader please skip to the [English section](#way-to-get-support) below**
|
||||
|
||||
## 获得帮助信息的途径
|
||||
|
||||
您可以从以下渠道获取帮助:
|
||||
|
||||
1. 官方网站:[v2ray.com](https://www.v2ray.com)
|
||||
1. Github:[Issues](https://github.com/v2ray/v2ray-core/issues)
|
||||
1. Telegram:[主群](https://t.me/projectv2ray)
|
||||
|
||||
## Github Issue 规则
|
||||
|
||||
1. 请按模板填写 issue;
|
||||
1. 配置文件内容使用格式化代码段进行修饰(见下面的解释);
|
||||
1. 在提交 issue 前尝试减化配置文件,比如删除不必要 inbound / outbound 模块;
|
||||
1. 在提交 issue 前尝试确定问题所在,比如将 socks 代理换成 http 再次观察问题是否能重现;
|
||||
1. 配置文件必须结构完整,即除了必要的隐私信息之外,配置文件可以直接拿来运行。
|
||||
|
||||
**不按模板填写的 issue 将直接被关闭**
|
||||
|
||||
## 格式化代码段
|
||||
|
||||
在配置文件上下加入 Markdown 特定的修饰符,如下:
|
||||
|
||||
\`\`\`javascript
|
||||
|
||||
{
|
||||
// 配置文件内容
|
||||
}
|
||||
|
||||
\`\`\`
|
||||
|
||||
## Way to Get Support
|
||||
|
||||
You may get help in the following ways:
|
||||
|
||||
1. Office Site: [v2ray.com](https://www.v2ray.com)
|
||||
1. Github: [Issues](https://github.com/v2ray/v2ray-core/issues)
|
||||
1. Telegram: [Main Group](https://t.me/projectv2ray)
|
||||
|
||||
## Github Issue Rules
|
||||
|
||||
1. Please fill in the issue template.
|
||||
1. Decorate config file with Markdown formatter (See below).
|
||||
1. Try to simplify config file before submitting the issue, such as removing unnecessary inbound / outbound blocks.
|
||||
1. Try to determine the cause of the issue, for example, replacing socks inbound with http inbound to see if the issue still exists.
|
||||
1. Config file must be structurally complete.
|
||||
|
||||
**Any issue not following the issue template will be closed immediately.**
|
||||
|
||||
## Code formatter
|
||||
|
||||
Add the following Markdown decorator to config file content:
|
||||
|
||||
\`\`\`javascript
|
||||
|
||||
{
|
||||
// config file
|
||||
}
|
||||
|
||||
\`\`\`
|
6
.gitmodules
vendored
6
.gitmodules
vendored
@ -1,3 +1,9 @@
|
||||
[submodule "vendor/h12.me/socks"]
|
||||
path = vendor/h12.me/socks
|
||||
url = https://github.com/h12w/socks
|
||||
[submodule "vendor/github.com/shadowsocks/go-shadowsocks2"]
|
||||
path = vendor/github.com/shadowsocks/go-shadowsocks2
|
||||
url = https://github.com/shadowsocks/go-shadowsocks2
|
||||
[submodule "vendor/github.com/Yawning/chacha20"]
|
||||
path = vendor/github.com/Yawning/chacha20
|
||||
url = https://github.com/Yawning/chacha20
|
||||
|
@ -1,7 +1,7 @@
|
||||
sudo: required
|
||||
language: go
|
||||
go:
|
||||
- 1.9
|
||||
- 1.9.2
|
||||
go_import_path: v2ray.com/core
|
||||
git:
|
||||
depth: 5
|
||||
|
19
.vscode/tasks.json
vendored
19
.vscode/tasks.json
vendored
@ -1,13 +1,18 @@
|
||||
{
|
||||
"version": "0.1.0",
|
||||
"version": "2.0.0",
|
||||
"command": "go",
|
||||
"isShellCommand": true,
|
||||
"showOutput": "always",
|
||||
"type": "shell",
|
||||
"presentation": {
|
||||
"echo": true,
|
||||
"reveal": "always",
|
||||
"focus": false,
|
||||
"panel": "shared"
|
||||
},
|
||||
"tasks": [
|
||||
{
|
||||
"taskName": "build",
|
||||
"label": "build",
|
||||
"args": ["v2ray.com/core/..."],
|
||||
"isBuildCommand": true,
|
||||
"group": "build",
|
||||
"problemMatcher": {
|
||||
"owner": "go",
|
||||
"fileLocation": ["relative", "${workspaceRoot}"],
|
||||
@ -20,9 +25,9 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"taskName": "test",
|
||||
"label": "test",
|
||||
"args": ["-p", "1", "v2ray.com/core/..."],
|
||||
"isBuildCommand": false
|
||||
"group": "test"
|
||||
}
|
||||
]
|
||||
}
|
16
README.md
16
README.md
@ -1,4 +1,4 @@
|
||||
# Project V2Ray
|
||||
# Project V
|
||||
|
||||
[![Build Status][1]][2] [![codecov.io][3]][4] [![Go Report][5]][6] [![GoDoc][7]][8] [![codebeat][9]][10]
|
||||
|
||||
@ -13,11 +13,21 @@
|
||||
[9]: https://codebeat.co/badges/f2354ca8-3e24-463d-a2e3-159af73b2477 "Codebeat badge"
|
||||
[10]: https://codebeat.co/projects/github-com-v2ray-v2ray-core-master "Codebeat"
|
||||
|
||||
V2Ray 是一个模块化的代理软件包,它的目标是提供常用的代理软件模块,简化网络代理软件的开发。
|
||||
V 是一个模块化的代理软件包,它的目标是提供常用的代理软件模块,简化网络代理软件的开发。
|
||||
|
||||
[官方网站](https://www.v2ray.com/)
|
||||
|
||||
V2Ray provides building blocks for network proxy development. Read our [Wiki](https://www.v2ray.com/en/index.html) for more information.
|
||||
V provides building blocks for network proxy development. Read our [Wiki](https://www.v2ray.com/en/index.html) for more information.
|
||||
|
||||
## License
|
||||
[The MIT License (MIT)](https://raw.githubusercontent.com/v2ray/v2ray-core/master/LICENSE)
|
||||
|
||||
## Credits
|
||||
|
||||
This repo relies on the following third-party projects:
|
||||
|
||||
* In production:
|
||||
* [miekg/dns](https://github.com/miekg/dns)
|
||||
* [gorilla/websocket](https://github.com/gorilla/websocket)
|
||||
* For testing only:
|
||||
* [h12w/socks](https://github.com/h12w/socks)
|
||||
|
@ -39,7 +39,7 @@ func NewDefaultDispatcher(ctx context.Context, config *dispatcher.Config) (*Defa
|
||||
return nil, newError("no space in context")
|
||||
}
|
||||
d := &DefaultDispatcher{}
|
||||
space.OnInitialize(func() error {
|
||||
space.On(app.SpaceInitializing, func(interface{}) error {
|
||||
d.ohm = proxyman.OutboundHandlerManagerFromSpace(space)
|
||||
if d.ohm == nil {
|
||||
return newError("OutboundHandlerManager is not found in the space")
|
||||
|
@ -3,12 +3,14 @@ package impl_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"v2ray.com/core/app/proxyman"
|
||||
|
||||
. "v2ray.com/core/app/dispatcher/impl"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestHTTPHeaders(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cases := []struct {
|
||||
input string
|
||||
@ -94,13 +96,13 @@ first_name=John&last_name=Doe&action=Submit`,
|
||||
|
||||
for _, test := range cases {
|
||||
domain, err := SniffHTTP([]byte(test.input))
|
||||
assert.String(domain).Equals(test.domain)
|
||||
assert.Error(err).Equals(test.err)
|
||||
assert(domain, Equals, test.domain)
|
||||
assert(err, Equals, test.err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSHeaders(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cases := []struct {
|
||||
input []byte
|
||||
@ -180,7 +182,13 @@ func TestTLSHeaders(t *testing.T) {
|
||||
|
||||
for _, test := range cases {
|
||||
domain, err := SniffTLS(test.input)
|
||||
assert.String(domain).Equals(test.domain)
|
||||
assert.Error(err).Equals(test.err)
|
||||
assert(domain, Equals, test.domain)
|
||||
assert(err, Equals, test.err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownSniffer(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
assert(func() { NewSniffer([]proxyman.KnownProtocols{proxyman.KnownProtocols(-1)}) }, Panics)
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
"v2ray.com/core/app/dispatcher"
|
||||
"v2ray.com/core/app/log"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/dice"
|
||||
"v2ray.com/core/common/net"
|
||||
@ -15,13 +16,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultTTL = uint32(3600)
|
||||
CleanupInterval = time.Second * 120
|
||||
CleanupThreshold = 512
|
||||
)
|
||||
|
||||
var (
|
||||
pseudoDestination = net.UDPDestination(net.LocalHostIP, net.Port(53))
|
||||
multiQuestionDNS = map[net.Address]bool{
|
||||
net.IPAddress([]byte{8, 8, 8, 8}): true,
|
||||
net.IPAddress([]byte{8, 8, 4, 4}): true,
|
||||
net.IPAddress([]byte{9, 9, 9, 9}): true,
|
||||
}
|
||||
)
|
||||
|
||||
type ARecord struct {
|
||||
@ -55,54 +59,52 @@ func NewUDPNameServer(address net.Destination, dispatcher dispatcher.Interface)
|
||||
return s
|
||||
}
|
||||
|
||||
// Private: Visible for testing.
|
||||
func (v *UDPNameServer) Cleanup() {
|
||||
func (s *UDPNameServer) Cleanup() {
|
||||
expiredRequests := make([]uint16, 0, 16)
|
||||
now := time.Now()
|
||||
v.Lock()
|
||||
for id, r := range v.requests {
|
||||
s.Lock()
|
||||
for id, r := range s.requests {
|
||||
if r.expire.Before(now) {
|
||||
expiredRequests = append(expiredRequests, id)
|
||||
close(r.response)
|
||||
}
|
||||
}
|
||||
for _, id := range expiredRequests {
|
||||
delete(v.requests, id)
|
||||
delete(s.requests, id)
|
||||
}
|
||||
v.Unlock()
|
||||
expiredRequests = nil
|
||||
s.Unlock()
|
||||
}
|
||||
|
||||
// Private: Visible for testing.
|
||||
func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
|
||||
func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
|
||||
var id uint16
|
||||
v.Lock()
|
||||
if len(v.requests) > CleanupThreshold && v.nextCleanup.Before(time.Now()) {
|
||||
v.nextCleanup = time.Now().Add(CleanupInterval)
|
||||
go v.Cleanup()
|
||||
s.Lock()
|
||||
if len(s.requests) > CleanupThreshold && s.nextCleanup.Before(time.Now()) {
|
||||
s.nextCleanup = time.Now().Add(CleanupInterval)
|
||||
go s.Cleanup()
|
||||
}
|
||||
|
||||
for {
|
||||
id = dice.RollUint16()
|
||||
if _, found := v.requests[id]; found {
|
||||
if _, found := s.requests[id]; found {
|
||||
continue
|
||||
}
|
||||
log.Trace(newError("add pending request id ", id).AtDebug())
|
||||
v.requests[id] = &PendingRequest{
|
||||
s.requests[id] = &PendingRequest{
|
||||
expire: time.Now().Add(time.Second * 8),
|
||||
response: response,
|
||||
}
|
||||
break
|
||||
}
|
||||
v.Unlock()
|
||||
s.Unlock()
|
||||
return id
|
||||
}
|
||||
|
||||
// Private: Visible for testing.
|
||||
func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
||||
func (s *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
||||
msg := new(dns.Msg)
|
||||
err := msg.Unpack(payload.Bytes())
|
||||
if err != nil {
|
||||
if err == dns.ErrTruncated {
|
||||
log.Trace(newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning())
|
||||
} else if err != nil {
|
||||
log.Trace(newError("failed to parse DNS response").Base(err).AtWarning())
|
||||
return
|
||||
}
|
||||
@ -110,17 +112,17 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
||||
IPs: make([]net.IP, 0, 16),
|
||||
}
|
||||
id := msg.Id
|
||||
ttl := DefaultTTL
|
||||
log.Trace(newError("handling response for id ", id, " content: ", msg.String()).AtDebug())
|
||||
ttl := uint32(3600) // an hour
|
||||
log.Trace(newError("handling response for id ", id, " content: ", msg).AtDebug())
|
||||
|
||||
v.Lock()
|
||||
request, found := v.requests[id]
|
||||
s.Lock()
|
||||
request, found := s.requests[id]
|
||||
if !found {
|
||||
v.Unlock()
|
||||
s.Unlock()
|
||||
return
|
||||
}
|
||||
delete(v.requests, id)
|
||||
v.Unlock()
|
||||
delete(s.requests, id)
|
||||
s.Unlock()
|
||||
|
||||
for _, rr := range msg.Answer {
|
||||
switch rr := rr.(type) {
|
||||
@ -142,8 +144,7 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
|
||||
close(request.response)
|
||||
}
|
||||
|
||||
func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
|
||||
|
||||
func (s *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
|
||||
msg := new(dns.Msg)
|
||||
msg.Id = id
|
||||
msg.RecursionDesired = true
|
||||
@ -153,34 +154,40 @@ func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}}
|
||||
if multiQuestionDNS[s.address.Address] {
|
||||
msg.Question = append(msg.Question, dns.Question{
|
||||
Name: dns.Fqdn(domain),
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
})
|
||||
}
|
||||
|
||||
buffer := buf.New()
|
||||
buffer.AppendSupplier(func(b []byte) (int, error) {
|
||||
common.Must(buffer.Reset(func(b []byte) (int, error) {
|
||||
writtenBuffer, err := msg.PackBuffer(b)
|
||||
return len(writtenBuffer), err
|
||||
})
|
||||
}))
|
||||
|
||||
return buffer
|
||||
}
|
||||
|
||||
func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
|
||||
func (s *UDPNameServer) QueryA(domain string) <-chan *ARecord {
|
||||
response := make(chan *ARecord, 1)
|
||||
id := v.AssignUnusedID(response)
|
||||
id := s.AssignUnusedID(response)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
|
||||
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.udpServer.Dispatch(ctx, s.address, s.BuildQueryA(domain, id), s.HandleResponse)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 2; i++ {
|
||||
time.Sleep(time.Second)
|
||||
v.Lock()
|
||||
_, found := v.requests[id]
|
||||
v.Unlock()
|
||||
if found {
|
||||
v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
|
||||
} else {
|
||||
s.Lock()
|
||||
_, found := s.requests[id]
|
||||
s.Unlock()
|
||||
if !found {
|
||||
break
|
||||
}
|
||||
s.udpServer.Dispatch(ctx, s.address, s.BuildQueryA(domain, id), s.HandleResponse)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
@ -191,7 +198,7 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
|
||||
type LocalNameServer struct {
|
||||
}
|
||||
|
||||
func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
|
||||
func (*LocalNameServer) QueryA(domain string) <-chan *ARecord {
|
||||
response := make(chan *ARecord, 1)
|
||||
|
||||
go func() {
|
||||
@ -205,7 +212,7 @@ func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
|
||||
|
||||
response <- &ARecord{
|
||||
IPs: ips,
|
||||
Expire: time.Now().Add(time.Second * time.Duration(DefaultTTL)),
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -21,11 +21,22 @@ const (
|
||||
)
|
||||
|
||||
type DomainRecord struct {
|
||||
A *ARecord
|
||||
IP []net.IP
|
||||
Expire time.Time
|
||||
LastAccess time.Time
|
||||
}
|
||||
|
||||
func (r *DomainRecord) Expired() bool {
|
||||
return r.Expire.Before(time.Now())
|
||||
}
|
||||
|
||||
func (r *DomainRecord) Inactive() bool {
|
||||
now := time.Now()
|
||||
return r.Expire.Before(now) || r.LastAccess.Add(time.Minute*5).Before(now)
|
||||
}
|
||||
|
||||
type CacheServer struct {
|
||||
sync.RWMutex
|
||||
sync.Mutex
|
||||
hosts map[string]net.IP
|
||||
records map[string]*DomainRecord
|
||||
servers []NameServer
|
||||
@ -41,7 +52,7 @@ func NewCacheServer(ctx context.Context, config *dns.Config) (*CacheServer, erro
|
||||
servers: make([]NameServer, len(config.NameServers)),
|
||||
hosts: config.GetInternalHosts(),
|
||||
}
|
||||
space.OnInitialize(func() error {
|
||||
space.On(app.SpaceInitializing, func(interface{}) error {
|
||||
disp := dispatcher.FromSpace(space)
|
||||
if disp == nil {
|
||||
return newError("dispatcher is not found in the space")
|
||||
@ -79,15 +90,33 @@ func (*CacheServer) Start() error {
|
||||
func (*CacheServer) Close() {}
|
||||
|
||||
func (s *CacheServer) GetCached(domain string) []net.IP {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if record, found := s.records[domain]; found && record.A.Expire.After(time.Now()) {
|
||||
return record.A.IPs
|
||||
if record, found := s.records[domain]; found && !record.Expired() {
|
||||
record.LastAccess = time.Now()
|
||||
return record.IP
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *CacheServer) tryCleanup() {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
if len(s.records) > 256 {
|
||||
domains := make([]string, 0, 256)
|
||||
for d, r := range s.records {
|
||||
if r.Expired() {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
}
|
||||
for _, d := range domains {
|
||||
delete(s.records, d)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CacheServer) Get(domain string) []net.IP {
|
||||
if ip, found := s.hosts[domain]; found {
|
||||
return []net.IP{ip}
|
||||
@ -99,6 +128,8 @@ func (s *CacheServer) Get(domain string) []net.IP {
|
||||
return ips
|
||||
}
|
||||
|
||||
s.tryCleanup()
|
||||
|
||||
for _, server := range s.servers {
|
||||
response := server.QueryA(domain)
|
||||
select {
|
||||
@ -108,7 +139,9 @@ func (s *CacheServer) Get(domain string) []net.IP {
|
||||
}
|
||||
s.Lock()
|
||||
s.records[domain] = &DomainRecord{
|
||||
A: a,
|
||||
IP: a.IPs,
|
||||
Expire: a.Expire,
|
||||
LastAccess: time.Now(),
|
||||
}
|
||||
s.Unlock()
|
||||
log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())
|
||||
|
105
app/dns/server/server_test.go
Normal file
105
app/dns/server/server_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"v2ray.com/core/app"
|
||||
"v2ray.com/core/app/dispatcher"
|
||||
_ "v2ray.com/core/app/dispatcher/impl"
|
||||
. "v2ray.com/core/app/dns"
|
||||
_ "v2ray.com/core/app/dns/server"
|
||||
"v2ray.com/core/app/policy"
|
||||
_ "v2ray.com/core/app/policy/manager"
|
||||
"v2ray.com/core/app/proxyman"
|
||||
_ "v2ray.com/core/app/proxyman/outbound"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/proxy/freedom"
|
||||
"v2ray.com/core/testing/servers/udp"
|
||||
. "v2ray.com/ext/assert"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type staticHandler struct {
|
||||
}
|
||||
|
||||
func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
ans := new(dns.Msg)
|
||||
ans.Id = r.Id
|
||||
for _, q := range r.Question {
|
||||
if q.Name == "google.com." && q.Qtype == dns.TypeA {
|
||||
rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
|
||||
ans.Answer = append(ans.Answer, rr)
|
||||
} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
|
||||
rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
|
||||
ans.Answer = append(ans.Answer, rr)
|
||||
}
|
||||
}
|
||||
w.WriteMsg(ans)
|
||||
}
|
||||
|
||||
func TestUDPServer(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
port := udp.PickPort()
|
||||
|
||||
dnsServer := dns.Server{
|
||||
Addr: "127.0.0.1:" + port.String(),
|
||||
Net: "udp",
|
||||
Handler: &staticHandler{},
|
||||
UDPSize: 1200,
|
||||
}
|
||||
|
||||
go dnsServer.ListenAndServe()
|
||||
|
||||
config := &Config{
|
||||
NameServers: []*net.Endpoint{
|
||||
{
|
||||
Network: net.Network_UDP,
|
||||
Address: &net.IPOrDomain{
|
||||
Address: &net.IPOrDomain_Ip{
|
||||
Ip: []byte{127, 0, 0, 1},
|
||||
},
|
||||
},
|
||||
Port: uint32(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
space := app.NewSpace()
|
||||
|
||||
ctx = app.ContextWithSpace(ctx, space)
|
||||
common.Must(app.AddApplicationToSpace(ctx, config))
|
||||
common.Must(app.AddApplicationToSpace(ctx, &dispatcher.Config{}))
|
||||
common.Must(app.AddApplicationToSpace(ctx, &proxyman.OutboundConfig{}))
|
||||
common.Must(app.AddApplicationToSpace(ctx, &policy.Config{}))
|
||||
|
||||
om := proxyman.OutboundHandlerManagerFromSpace(space)
|
||||
om.AddHandler(ctx, &proxyman.OutboundHandlerConfig{
|
||||
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
|
||||
})
|
||||
|
||||
common.Must(space.Initialize())
|
||||
common.Must(space.Start())
|
||||
|
||||
server := FromSpace(space)
|
||||
assert(server, IsNotNil)
|
||||
|
||||
ips := server.Get("google.com")
|
||||
assert(len(ips), Equals, 1)
|
||||
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
|
||||
|
||||
ips = server.Get("facebook.com")
|
||||
assert(len(ips), Equals, 1)
|
||||
assert([]byte(ips[0]), Equals, []byte{9, 9, 9, 9})
|
||||
|
||||
dnsServer.Shutdown()
|
||||
|
||||
ips = server.Get("google.com")
|
||||
assert(len(ips), Equals, 1)
|
||||
assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
|
||||
}
|
@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/app/log/internal"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
entry := &AccessLog{
|
||||
From: "test_from",
|
||||
@ -18,8 +18,8 @@ func TestAccessLog(t *testing.T) {
|
||||
}
|
||||
|
||||
entryStr := entry.String()
|
||||
assert.String(entryStr).Contains("test_from")
|
||||
assert.String(entryStr).Contains("test_to")
|
||||
assert.String(entryStr).Contains("test_reason")
|
||||
assert.String(entryStr).Contains("Accepted")
|
||||
assert(entryStr, HasSubstring, "test_from")
|
||||
assert(entryStr, HasSubstring, "test_to")
|
||||
assert(entryStr, HasSubstring, "test_reason")
|
||||
assert(entryStr, HasSubstring, "Accepted")
|
||||
}
|
||||
|
26
app/policy/config.go
Normal file
26
app/policy/config.go
Normal file
@ -0,0 +1,26 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Second) Duration() time.Duration {
|
||||
return time.Second * time.Duration(s.Value)
|
||||
}
|
||||
|
||||
func (p *Policy) OverrideWith(another *Policy) {
|
||||
if another.Timeout != nil {
|
||||
if another.Timeout.Handshake != nil {
|
||||
p.Timeout.Handshake = another.Timeout.Handshake
|
||||
}
|
||||
if another.Timeout.ConnectionIdle != nil {
|
||||
p.Timeout.ConnectionIdle = another.Timeout.ConnectionIdle
|
||||
}
|
||||
if another.Timeout.UplinkOnly != nil {
|
||||
p.Timeout.UplinkOnly = another.Timeout.UplinkOnly
|
||||
}
|
||||
if another.Timeout.DownlinkOnly != nil {
|
||||
p.Timeout.DownlinkOnly = another.Timeout.DownlinkOnly
|
||||
}
|
||||
}
|
||||
}
|
140
app/policy/config.pb.go
Normal file
140
app/policy/config.pb.go
Normal file
@ -0,0 +1,140 @@
|
||||
package policy
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type Second struct {
|
||||
Value uint32 `protobuf:"varint,1,opt,name=value" json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Second) Reset() { *m = Second{} }
|
||||
func (m *Second) String() string { return proto.CompactTextString(m) }
|
||||
func (*Second) ProtoMessage() {}
|
||||
func (*Second) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func (m *Second) GetValue() uint32 {
|
||||
if m != nil {
|
||||
return m.Value
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type Policy struct {
|
||||
Timeout *Policy_Timeout `protobuf:"bytes,1,opt,name=timeout" json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Policy) Reset() { *m = Policy{} }
|
||||
func (m *Policy) String() string { return proto.CompactTextString(m) }
|
||||
func (*Policy) ProtoMessage() {}
|
||||
func (*Policy) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
|
||||
|
||||
func (m *Policy) GetTimeout() *Policy_Timeout {
|
||||
if m != nil {
|
||||
return m.Timeout
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Timeout is a message for timeout settings in various stages, in seconds.
|
||||
type Policy_Timeout struct {
|
||||
Handshake *Second `protobuf:"bytes,1,opt,name=handshake" json:"handshake,omitempty"`
|
||||
ConnectionIdle *Second `protobuf:"bytes,2,opt,name=connection_idle,json=connectionIdle" json:"connection_idle,omitempty"`
|
||||
UplinkOnly *Second `protobuf:"bytes,3,opt,name=uplink_only,json=uplinkOnly" json:"uplink_only,omitempty"`
|
||||
DownlinkOnly *Second `protobuf:"bytes,4,opt,name=downlink_only,json=downlinkOnly" json:"downlink_only,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Policy_Timeout) Reset() { *m = Policy_Timeout{} }
|
||||
func (m *Policy_Timeout) String() string { return proto.CompactTextString(m) }
|
||||
func (*Policy_Timeout) ProtoMessage() {}
|
||||
func (*Policy_Timeout) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1, 0} }
|
||||
|
||||
func (m *Policy_Timeout) GetHandshake() *Second {
|
||||
if m != nil {
|
||||
return m.Handshake
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Policy_Timeout) GetConnectionIdle() *Second {
|
||||
if m != nil {
|
||||
return m.ConnectionIdle
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Policy_Timeout) GetUplinkOnly() *Second {
|
||||
if m != nil {
|
||||
return m.UplinkOnly
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Policy_Timeout) GetDownlinkOnly() *Second {
|
||||
if m != nil {
|
||||
return m.DownlinkOnly
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Level map[uint32]*Policy `protobuf:"bytes,1,rep,name=level" json:"level,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
|
||||
}
|
||||
|
||||
func (m *Config) Reset() { *m = Config{} }
|
||||
func (m *Config) String() string { return proto.CompactTextString(m) }
|
||||
func (*Config) ProtoMessage() {}
|
||||
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||
|
||||
func (m *Config) GetLevel() map[uint32]*Policy {
|
||||
if m != nil {
|
||||
return m.Level
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Second)(nil), "v2ray.core.app.policy.Second")
|
||||
proto.RegisterType((*Policy)(nil), "v2ray.core.app.policy.Policy")
|
||||
proto.RegisterType((*Policy_Timeout)(nil), "v2ray.core.app.policy.Policy.Timeout")
|
||||
proto.RegisterType((*Config)(nil), "v2ray.core.app.policy.Config")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("v2ray.com/core/app/policy/config.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 349 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xc1, 0x4a, 0xeb, 0x40,
|
||||
0x14, 0x86, 0x49, 0x7a, 0x9b, 0x72, 0x4f, 0x6f, 0xaf, 0x32, 0x58, 0x88, 0x05, 0xa5, 0x14, 0x94,
|
||||
0xae, 0x26, 0x90, 0x6e, 0x44, 0xb1, 0x62, 0x45, 0x41, 0x10, 0x2c, 0x51, 0x14, 0xdc, 0x94, 0x71,
|
||||
0x32, 0xda, 0xd0, 0xe9, 0x9c, 0x21, 0xa6, 0x95, 0xbc, 0x86, 0x6f, 0xe0, 0xd6, 0x87, 0xf2, 0x59,
|
||||
0x24, 0x99, 0x84, 0x6c, 0x5a, 0xe9, 0x6e, 0x72, 0xf8, 0xfe, 0x8f, 0x43, 0xfe, 0x03, 0x87, 0x4b,
|
||||
0x3f, 0x66, 0x29, 0xe5, 0x38, 0xf7, 0x38, 0xc6, 0xc2, 0x63, 0x5a, 0x7b, 0x1a, 0x65, 0xc4, 0x53,
|
||||
0x8f, 0xa3, 0x7a, 0x89, 0x5e, 0xa9, 0x8e, 0x31, 0x41, 0xd2, 0x2e, 0xb9, 0x58, 0x50, 0xa6, 0x35,
|
||||
0x35, 0x4c, 0x6f, 0x1f, 0x9c, 0x3b, 0xc1, 0x51, 0x85, 0x64, 0x07, 0xea, 0x4b, 0x26, 0x17, 0xc2,
|
||||
0xb5, 0xba, 0x56, 0xbf, 0x15, 0x98, 0x8f, 0xde, 0xb7, 0x0d, 0xce, 0x38, 0x47, 0xc9, 0x19, 0x34,
|
||||
0x92, 0x68, 0x2e, 0x70, 0x91, 0xe4, 0x48, 0xd3, 0x3f, 0xa0, 0x2b, 0x9d, 0xd4, 0xf0, 0xf4, 0xde,
|
||||
0xc0, 0x41, 0x99, 0xea, 0x7c, 0xd8, 0xd0, 0x28, 0x86, 0xe4, 0x04, 0xfe, 0x4e, 0x99, 0x0a, 0xdf,
|
||||
0xa6, 0x6c, 0x26, 0x0a, 0xdd, 0xde, 0x1a, 0x9d, 0xd9, 0x2f, 0xa8, 0x78, 0x72, 0x05, 0x5b, 0x1c,
|
||||
0x95, 0x12, 0x3c, 0x89, 0x50, 0x4d, 0xa2, 0x50, 0x0a, 0xd7, 0xde, 0x44, 0xf1, 0xbf, 0x4a, 0x5d,
|
||||
0x87, 0x52, 0x90, 0x21, 0x34, 0x17, 0x5a, 0x46, 0x6a, 0x36, 0x41, 0x25, 0x53, 0xb7, 0xb6, 0x89,
|
||||
0x03, 0x4c, 0xe2, 0x56, 0xc9, 0x94, 0x8c, 0xa0, 0x15, 0xe2, 0xbb, 0xaa, 0x0c, 0x7f, 0x36, 0x31,
|
||||
0xfc, 0x2b, 0x33, 0x99, 0xa3, 0xf7, 0x69, 0x81, 0x73, 0x91, 0x17, 0x45, 0x86, 0x50, 0x97, 0x62,
|
||||
0x29, 0xa4, 0x6b, 0x75, 0x6b, 0xfd, 0xa6, 0xdf, 0x5f, 0xa3, 0x31, 0x34, 0xbd, 0xc9, 0xd0, 0x4b,
|
||||
0x95, 0xc4, 0x69, 0x60, 0x62, 0x9d, 0x47, 0x80, 0x6a, 0x48, 0xb6, 0xa1, 0x36, 0x13, 0x69, 0xd1,
|
||||
0x66, 0xf6, 0x24, 0x83, 0xb2, 0xe1, 0xdf, 0x7f, 0x96, 0xa9, 0xaf, 0x38, 0x80, 0x63, 0xfb, 0xc8,
|
||||
0x1a, 0x9d, 0xc2, 0x2e, 0xc7, 0xf9, 0x6a, 0x7c, 0x6c, 0x3d, 0x39, 0xe6, 0xf5, 0x65, 0xb7, 0x1f,
|
||||
0xfc, 0x80, 0x65, 0x0b, 0xc6, 0x82, 0x9e, 0x6b, 0x5d, 0x98, 0x9e, 0x9d, 0xfc, 0x02, 0x07, 0x3f,
|
||||
0x01, 0x00, 0x00, 0xff, 0xff, 0xcf, 0x25, 0x25, 0xc2, 0xab, 0x02, 0x00, 0x00,
|
||||
}
|
27
app/policy/config.proto
Normal file
27
app/policy/config.proto
Normal file
@ -0,0 +1,27 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package v2ray.core.app.policy;
|
||||
option csharp_namespace = "V2Ray.Core.App.Policy";
|
||||
option go_package = "policy";
|
||||
option java_package = "com.v2ray.core.app.policy";
|
||||
option java_multiple_files = true;
|
||||
|
||||
message Second {
|
||||
uint32 value = 1;
|
||||
}
|
||||
|
||||
message Policy {
|
||||
// Timeout is a message for timeout settings in various stages, in seconds.
|
||||
message Timeout {
|
||||
Second handshake = 1;
|
||||
Second connection_idle = 2;
|
||||
Second uplink_only = 3;
|
||||
Second downlink_only = 4;
|
||||
}
|
||||
|
||||
Timeout timeout = 1;
|
||||
}
|
||||
|
||||
message Config {
|
||||
map<uint32, Policy> level = 1;
|
||||
}
|
66
app/policy/manager/manager.go
Normal file
66
app/policy/manager/manager.go
Normal file
@ -0,0 +1,66 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"v2ray.com/core/app/policy"
|
||||
"v2ray.com/core/common"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
levels map[uint32]*policy.Policy
|
||||
}
|
||||
|
||||
func New(ctx context.Context, config *policy.Config) (*Instance, error) {
|
||||
levels := config.Level
|
||||
if levels == nil {
|
||||
levels = make(map[uint32]*policy.Policy)
|
||||
}
|
||||
for _, p := range levels {
|
||||
g := global()
|
||||
g.OverrideWith(p)
|
||||
*p = g
|
||||
}
|
||||
return &Instance{
|
||||
levels: levels,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func global() policy.Policy {
|
||||
return policy.Policy{
|
||||
Timeout: &policy.Policy_Timeout{
|
||||
Handshake: &policy.Second{Value: 4},
|
||||
ConnectionIdle: &policy.Second{Value: 300},
|
||||
UplinkOnly: &policy.Second{Value: 5},
|
||||
DownlinkOnly: &policy.Second{Value: 30},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetPolicy implements policy.Manager.
|
||||
func (m *Instance) GetPolicy(level uint32) policy.Policy {
|
||||
if p, ok := m.levels[level]; ok {
|
||||
return *p
|
||||
}
|
||||
return global()
|
||||
}
|
||||
|
||||
// Start implements app.Application.Start().
|
||||
func (m *Instance) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close implements app.Application.Close().
|
||||
func (m *Instance) Close() {
|
||||
}
|
||||
|
||||
// Interface implement app.Application.Interface().
|
||||
func (m *Instance) Interface() interface{} {
|
||||
return (*policy.Manager)(nil)
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*policy.Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
return New(ctx, config.(*policy.Config))
|
||||
}))
|
||||
}
|
20
app/policy/policy.go
Normal file
20
app/policy/policy.go
Normal file
@ -0,0 +1,20 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"v2ray.com/core/app"
|
||||
)
|
||||
|
||||
// Manager is an utility to manage policy per user level.
|
||||
type Manager interface {
|
||||
// GetPolicy returns the Policy for the given user level.
|
||||
GetPolicy(level uint32) Policy
|
||||
}
|
||||
|
||||
// FromSpace returns the policy.Manager in a space.
|
||||
func FromSpace(space app.Space) Manager {
|
||||
app := space.GetApplication((*Manager)(nil))
|
||||
if app == nil {
|
||||
return nil
|
||||
}
|
||||
return app.(Manager)
|
||||
}
|
@ -172,6 +172,11 @@ func (*udpConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type connId struct {
|
||||
src net.Destination
|
||||
dest net.Destination
|
||||
}
|
||||
|
||||
type udpWorker struct {
|
||||
sync.RWMutex
|
||||
|
||||
@ -185,39 +190,43 @@ type udpWorker struct {
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
activeConn map[net.Destination]*udpConn
|
||||
activeConn map[connId]*udpConn
|
||||
}
|
||||
|
||||
func (w *udpWorker) getConnection(src net.Destination) (*udpConn, bool) {
|
||||
func (w *udpWorker) getConnection(id connId) (*udpConn, bool) {
|
||||
w.Lock()
|
||||
defer w.Unlock()
|
||||
|
||||
if conn, found := w.activeConn[src]; found {
|
||||
if conn, found := w.activeConn[id]; found {
|
||||
return conn, true
|
||||
}
|
||||
|
||||
conn := &udpConn{
|
||||
input: make(chan *buf.Buffer, 32),
|
||||
output: func(b []byte) (int, error) {
|
||||
return w.hub.WriteTo(b, src)
|
||||
return w.hub.WriteTo(b, id.src)
|
||||
},
|
||||
remote: &net.UDPAddr{
|
||||
IP: src.Address.IP(),
|
||||
Port: int(src.Port),
|
||||
IP: id.src.Address.IP(),
|
||||
Port: int(id.src.Port),
|
||||
},
|
||||
local: &net.UDPAddr{
|
||||
IP: w.address.IP(),
|
||||
Port: int(w.port),
|
||||
},
|
||||
}
|
||||
w.activeConn[src] = conn
|
||||
w.activeConn[id] = conn
|
||||
|
||||
conn.updateActivity()
|
||||
return conn, false
|
||||
}
|
||||
|
||||
func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
|
||||
conn, existing := w.getConnection(source)
|
||||
id := connId{
|
||||
src: source,
|
||||
dest: originalDest,
|
||||
}
|
||||
conn, existing := w.getConnection(id)
|
||||
select {
|
||||
case conn.input <- b:
|
||||
default:
|
||||
@ -240,20 +249,20 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
|
||||
if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
|
||||
log.Trace(newError("connection ends").Base(err))
|
||||
}
|
||||
w.removeConn(source)
|
||||
w.removeConn(id)
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *udpWorker) removeConn(src net.Destination) {
|
||||
func (w *udpWorker) removeConn(id connId) {
|
||||
w.Lock()
|
||||
delete(w.activeConn, src)
|
||||
delete(w.activeConn, id)
|
||||
w.Unlock()
|
||||
}
|
||||
|
||||
func (w *udpWorker) Start() error {
|
||||
w.activeConn = make(map[net.Destination]*udpConn)
|
||||
w.activeConn = make(map[connId]*udpConn, 16)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.ctx = ctx
|
||||
w.cancel = cancel
|
||||
|
@ -1,8 +1,10 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"v2ray.com/core/common/bitmask"
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/common/serial"
|
||||
)
|
||||
|
||||
@ -15,24 +17,10 @@ const (
|
||||
SessionStatusKeepAlive SessionStatus = 0x04
|
||||
)
|
||||
|
||||
type Option byte
|
||||
|
||||
const (
|
||||
OptionData Option = 0x01
|
||||
OptionData bitmask.Byte = 0x01
|
||||
)
|
||||
|
||||
func (o Option) Has(x Option) bool {
|
||||
return (o & x) == x
|
||||
}
|
||||
|
||||
func (o *Option) Add(x Option) {
|
||||
*o = (*o | x)
|
||||
}
|
||||
|
||||
func (o *Option) Clear(x Option) {
|
||||
*o = (*o & (^x))
|
||||
}
|
||||
|
||||
type TargetNetwork byte
|
||||
|
||||
const (
|
||||
@ -40,14 +28,6 @@ const (
|
||||
TargetNetworkUDP TargetNetwork = 0x02
|
||||
)
|
||||
|
||||
type AddressType byte
|
||||
|
||||
const (
|
||||
AddressTypeIPv4 AddressType = 0x01
|
||||
AddressTypeDomain AddressType = 0x02
|
||||
AddressTypeIPv6 AddressType = 0x03
|
||||
)
|
||||
|
||||
/*
|
||||
Frame format
|
||||
2 bytes - length
|
||||
@ -62,10 +42,10 @@ n bytes - address
|
||||
*/
|
||||
|
||||
type FrameMetadata struct {
|
||||
SessionID uint16
|
||||
SessionStatus SessionStatus
|
||||
Target net.Destination
|
||||
Option Option
|
||||
SessionID uint16
|
||||
Option bitmask.Byte
|
||||
SessionStatus SessionStatus
|
||||
}
|
||||
|
||||
func (f FrameMetadata) AsSupplier() buf.Supplier {
|
||||
@ -92,17 +72,21 @@ func (f FrameMetadata) AsSupplier() buf.Supplier {
|
||||
addr := f.Target.Address
|
||||
switch addr.Family() {
|
||||
case net.AddressFamilyIPv4:
|
||||
b = append(b, byte(AddressTypeIPv4))
|
||||
b = append(b, byte(protocol.AddressTypeIPv4))
|
||||
b = append(b, addr.IP()...)
|
||||
length += 5
|
||||
case net.AddressFamilyIPv6:
|
||||
b = append(b, byte(AddressTypeIPv6))
|
||||
b = append(b, byte(protocol.AddressTypeIPv6))
|
||||
b = append(b, addr.IP()...)
|
||||
length += 17
|
||||
case net.AddressFamilyDomain:
|
||||
nDomain := len(addr.Domain())
|
||||
b = append(b, byte(AddressTypeDomain), byte(nDomain))
|
||||
b = append(b, addr.Domain()...)
|
||||
domain := addr.Domain()
|
||||
if protocol.IsDomainTooLong(domain) {
|
||||
return 0, newError("domain name too long: ", domain)
|
||||
}
|
||||
nDomain := len(domain)
|
||||
b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain))
|
||||
b = append(b, domain...)
|
||||
length += nDomain + 2
|
||||
}
|
||||
}
|
||||
@ -120,7 +104,7 @@ func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
|
||||
f := &FrameMetadata{
|
||||
SessionID: serial.BytesToUint16(b[:2]),
|
||||
SessionStatus: SessionStatus(b[2]),
|
||||
Option: Option(b[3]),
|
||||
Option: bitmask.Byte(b[3]),
|
||||
}
|
||||
|
||||
b = b[4:]
|
||||
@ -128,18 +112,18 @@ func ReadFrameFrom(b []byte) (*FrameMetadata, error) {
|
||||
if f.SessionStatus == SessionStatusNew {
|
||||
network := TargetNetwork(b[0])
|
||||
port := net.PortFromBytes(b[1:3])
|
||||
addrType := AddressType(b[3])
|
||||
addrType := protocol.AddressType(b[3])
|
||||
b = b[4:]
|
||||
|
||||
var addr net.Address
|
||||
switch addrType {
|
||||
case AddressTypeIPv4:
|
||||
case protocol.AddressTypeIPv4:
|
||||
addr = net.IPAddress(b[0:4])
|
||||
b = b[4:]
|
||||
case AddressTypeIPv6:
|
||||
case protocol.AddressTypeIPv6:
|
||||
addr = net.IPAddress(b[0:16])
|
||||
b = b[16:]
|
||||
case AddressTypeDomain:
|
||||
case protocol.AddressTypeDomain:
|
||||
nDomain := int(b[0])
|
||||
addr = net.DomainAddress(string(b[1 : 1+nDomain]))
|
||||
b = b[nDomain+1:]
|
||||
|
@ -90,7 +90,19 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(muxCoolAddress, muxCoolPort))
|
||||
pipe := ray.NewRay(ctx)
|
||||
go p.Process(ctx, pipe, dialer)
|
||||
|
||||
go func() {
|
||||
if err := p.Process(ctx, pipe, dialer); err != nil {
|
||||
cancel()
|
||||
|
||||
traceErr := errors.New("failed to handler mux client connection").Base(err)
|
||||
if err != io.EOF && err != context.Canceled {
|
||||
traceErr = traceErr.AtWarning()
|
||||
}
|
||||
log.Trace(traceErr)
|
||||
}
|
||||
}()
|
||||
|
||||
c := &Client{
|
||||
sessionManager: NewSessionManager(),
|
||||
inboundRay: pipe,
|
||||
@ -104,6 +116,7 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Closed returns true if this Client is closed.
|
||||
func (m *Client) Closed() bool {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
@ -148,7 +161,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
|
||||
|
||||
log.Trace(newError("dispatching request to ", dest))
|
||||
data, _ := s.input.ReadTimeout(time.Millisecond * 500)
|
||||
if err := writer.Write(data); err != nil {
|
||||
if err := writer.WriteMultiBuffer(data); err != nil {
|
||||
log.Trace(newError("failed to write first payload").Base(err))
|
||||
return
|
||||
}
|
||||
@ -179,26 +192,25 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
|
||||
return true
|
||||
}
|
||||
|
||||
func drain(reader io.Reader) error {
|
||||
buf.Copy(NewStreamReader(reader), buf.Discard)
|
||||
return nil
|
||||
func drain(reader *buf.BufferedReader) error {
|
||||
return buf.Copy(NewStreamReader(reader), buf.Discard)
|
||||
}
|
||||
|
||||
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if meta.Option.Has(OptionData) {
|
||||
return drain(reader)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Client) handleStatusNew(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (m *Client) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if meta.Option.Has(OptionData) {
|
||||
return drain(reader)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (m *Client) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if !meta.Option.Has(OptionData) {
|
||||
return nil
|
||||
}
|
||||
@ -209,7 +221,7 @@ func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
|
||||
return drain(reader)
|
||||
}
|
||||
|
||||
func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if s, found := m.sessionManager.Get(meta.SessionID); found {
|
||||
s.Close()
|
||||
}
|
||||
@ -222,11 +234,10 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (m *Client) fetchOutput() {
|
||||
defer m.cancel()
|
||||
|
||||
reader := buf.ToBytesReader(m.inboundRay.InboundOutput())
|
||||
metaReader := NewMetadataReader(reader)
|
||||
reader := buf.NewBufferedReader(m.inboundRay.InboundOutput())
|
||||
|
||||
for {
|
||||
meta, err := metaReader.Read()
|
||||
meta, err := ReadMetadata(reader)
|
||||
if err != nil {
|
||||
if errors.Cause(err) != io.EOF {
|
||||
log.Trace(newError("failed to read metadata").Base(err))
|
||||
@ -263,7 +274,7 @@ type Server struct {
|
||||
func NewServer(ctx context.Context) *Server {
|
||||
s := &Server{}
|
||||
space := app.SpaceFromContext(ctx)
|
||||
space.OnInitialize(func() error {
|
||||
space.On(app.SpaceInitializing, func(interface{}) error {
|
||||
d := dispatcher.FromSpace(space)
|
||||
if d == nil {
|
||||
return newError("no dispatcher in space")
|
||||
@ -304,14 +315,14 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
|
||||
s.Close()
|
||||
}
|
||||
|
||||
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if meta.Option.Has(OptionData) {
|
||||
return drain(reader)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader io.Reader) error {
|
||||
func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
log.Trace(newError("received request for ", meta.Target))
|
||||
inboundRay, err := w.dispatcher.Dispatch(ctx, meta.Target)
|
||||
if err != nil {
|
||||
@ -338,7 +349,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if !meta.Option.Has(OptionData) {
|
||||
return nil
|
||||
}
|
||||
@ -348,7 +359,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) e
|
||||
return drain(reader)
|
||||
}
|
||||
|
||||
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
|
||||
func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
if s, found := w.sessionManager.Get(meta.SessionID); found {
|
||||
s.Close()
|
||||
}
|
||||
@ -358,9 +369,8 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error {
|
||||
metaReader := NewMetadataReader(reader)
|
||||
meta, err := metaReader.Read()
|
||||
func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error {
|
||||
meta, err := ReadMetadata(reader)
|
||||
if err != nil {
|
||||
return newError("failed to read metadata").Base(err)
|
||||
}
|
||||
@ -386,7 +396,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error
|
||||
|
||||
func (w *ServerWorker) run(ctx context.Context) {
|
||||
input := w.outboundRay.OutboundInput()
|
||||
reader := buf.ToBytesReader(input)
|
||||
reader := buf.NewBufferedReader(input)
|
||||
|
||||
defer w.sessionManager.Close()
|
||||
|
||||
|
@ -9,14 +9,14 @@ import (
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/testing/assert"
|
||||
"v2ray.com/core/transport/ray"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
|
||||
mb := buf.NewMultiBuffer()
|
||||
var mb buf.MultiBuffer
|
||||
for {
|
||||
b, err := reader.Read()
|
||||
b, err := reader.ReadMultiBuffer()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
@ -29,7 +29,7 @@ func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
|
||||
}
|
||||
|
||||
func TestReaderWriter(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
stream := ray.NewStream(context.Background())
|
||||
|
||||
@ -45,98 +45,98 @@ func TestReaderWriter(t *testing.T) {
|
||||
writePayload := func(writer *Writer, payload ...byte) error {
|
||||
b := buf.New()
|
||||
b.Append(payload)
|
||||
return writer.Write(buf.NewMultiBufferValue(b))
|
||||
return writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
|
||||
}
|
||||
|
||||
assert.Error(writePayload(writer, 'a', 'b', 'c', 'd')).IsNil()
|
||||
assert.Error(writePayload(writer2)).IsNil()
|
||||
assert(writePayload(writer, 'a', 'b', 'c', 'd'), IsNil)
|
||||
assert(writePayload(writer2), IsNil)
|
||||
|
||||
assert.Error(writePayload(writer, 'e', 'f', 'g', 'h')).IsNil()
|
||||
assert.Error(writePayload(writer3, 'x')).IsNil()
|
||||
assert(writePayload(writer, 'e', 'f', 'g', 'h'), IsNil)
|
||||
assert(writePayload(writer3, 'x'), IsNil)
|
||||
|
||||
writer.Close()
|
||||
writer3.Close()
|
||||
|
||||
assert.Error(writePayload(writer2, 'y')).IsNil()
|
||||
assert(writePayload(writer2, 'y'), IsNil)
|
||||
writer2.Close()
|
||||
|
||||
bytesReader := buf.ToBytesReader(stream)
|
||||
metaReader := NewMetadataReader(bytesReader)
|
||||
bytesReader := buf.NewBufferedReader(stream)
|
||||
streamReader := NewStreamReader(bytesReader)
|
||||
|
||||
meta, err := metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Uint16(meta.SessionID).Equals(1)
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
|
||||
assert.Destination(meta.Target).Equals(dest)
|
||||
assert.Byte(byte(meta.Option)).Equals(byte(OptionData))
|
||||
meta, err := ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(meta.SessionID, Equals, uint16(1))
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
|
||||
assert(meta.Target, Equals, dest)
|
||||
assert(byte(meta.Option), Equals, byte(OptionData))
|
||||
|
||||
data, err := readAll(streamReader)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(len(data)).Equals(1)
|
||||
assert.String(data[0].String()).Equals("abcd")
|
||||
assert(err, IsNil)
|
||||
assert(len(data), Equals, 1)
|
||||
assert(data[0].String(), Equals, "abcd")
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
|
||||
assert.Uint16(meta.SessionID).Equals(2)
|
||||
assert.Byte(byte(meta.Option)).Equals(0)
|
||||
assert.Destination(meta.Target).Equals(dest2)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
|
||||
assert(meta.SessionID, Equals, uint16(2))
|
||||
assert(byte(meta.Option), Equals, byte(0))
|
||||
assert(meta.Target, Equals, dest2)
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep))
|
||||
assert.Uint16(meta.SessionID).Equals(1)
|
||||
assert.Byte(byte(meta.Option)).Equals(1)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
|
||||
assert(meta.SessionID, Equals, uint16(1))
|
||||
assert(byte(meta.Option), Equals, byte(1))
|
||||
|
||||
data, err = readAll(streamReader)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(len(data)).Equals(1)
|
||||
assert.String(data[0].String()).Equals("efgh")
|
||||
assert(err, IsNil)
|
||||
assert(len(data), Equals, 1)
|
||||
assert(data[0].String(), Equals, "efgh")
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
|
||||
assert.Uint16(meta.SessionID).Equals(3)
|
||||
assert.Byte(byte(meta.Option)).Equals(1)
|
||||
assert.Destination(meta.Target).Equals(dest3)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusNew))
|
||||
assert(meta.SessionID, Equals, uint16(3))
|
||||
assert(byte(meta.Option), Equals, byte(1))
|
||||
assert(meta.Target, Equals, dest3)
|
||||
|
||||
data, err = readAll(streamReader)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(len(data)).Equals(1)
|
||||
assert.String(data[0].String()).Equals("x")
|
||||
assert(err, IsNil)
|
||||
assert(len(data), Equals, 1)
|
||||
assert(data[0].String(), Equals, "x")
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
|
||||
assert.Uint16(meta.SessionID).Equals(1)
|
||||
assert.Byte(byte(meta.Option)).Equals(0)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
|
||||
assert(meta.SessionID, Equals, uint16(1))
|
||||
assert(byte(meta.Option), Equals, byte(0))
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
|
||||
assert.Uint16(meta.SessionID).Equals(3)
|
||||
assert.Byte(byte(meta.Option)).Equals(0)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
|
||||
assert(meta.SessionID, Equals, uint16(3))
|
||||
assert(byte(meta.Option), Equals, byte(0))
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep))
|
||||
assert.Uint16(meta.SessionID).Equals(2)
|
||||
assert.Byte(byte(meta.Option)).Equals(1)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusKeep))
|
||||
assert(meta.SessionID, Equals, uint16(2))
|
||||
assert(byte(meta.Option), Equals, byte(1))
|
||||
|
||||
data, err = readAll(streamReader)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(len(data)).Equals(1)
|
||||
assert.String(data[0].String()).Equals("y")
|
||||
assert(err, IsNil)
|
||||
assert(len(data), Equals, 1)
|
||||
assert(data[0].String(), Equals, "y")
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
|
||||
assert.Uint16(meta.SessionID).Equals(2)
|
||||
assert.Byte(byte(meta.Option)).Equals(0)
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNil)
|
||||
assert(byte(meta.SessionStatus), Equals, byte(SessionStatusEnd))
|
||||
assert(meta.SessionID, Equals, uint16(2))
|
||||
assert(byte(meta.Option), Equals, byte(0))
|
||||
|
||||
stream.Close()
|
||||
|
||||
meta, err = metaReader.Read()
|
||||
assert.Error(err).IsNotNil()
|
||||
meta, err = ReadMetadata(bytesReader)
|
||||
assert(err, IsNotNil)
|
||||
assert(meta, IsNil)
|
||||
}
|
||||
|
@ -7,20 +7,9 @@ import (
|
||||
"v2ray.com/core/common/serial"
|
||||
)
|
||||
|
||||
type MetadataReader struct {
|
||||
reader io.Reader
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func NewMetadataReader(reader io.Reader) *MetadataReader {
|
||||
return &MetadataReader{
|
||||
reader: reader,
|
||||
buffer: make([]byte, 1024),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MetadataReader) Read() (*FrameMetadata, error) {
|
||||
metaLen, err := serial.ReadUint16(r.reader)
|
||||
// ReadMetadata reads FrameMetadata from the given reader.
|
||||
func ReadMetadata(reader io.Reader) (*FrameMetadata, error) {
|
||||
metaLen, err := serial.ReadUint16(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -28,17 +17,22 @@ func (r *MetadataReader) Read() (*FrameMetadata, error) {
|
||||
return nil, newError("invalid metalen ", metaLen).AtWarning()
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r.reader, r.buffer[:metaLen]); err != nil {
|
||||
b := buf.New()
|
||||
defer b.Release()
|
||||
|
||||
if err := b.Reset(buf.ReadFullFrom(reader, int(metaLen))); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ReadFrameFrom(r.buffer)
|
||||
return ReadFrameFrom(b.Bytes())
|
||||
}
|
||||
|
||||
// PacketReader is an io.Reader that reads whole chunk of Mux frames every time.
|
||||
type PacketReader struct {
|
||||
reader io.Reader
|
||||
eof bool
|
||||
}
|
||||
|
||||
// NewPacketReader creates a new PacketReader.
|
||||
func NewPacketReader(reader io.Reader) *PacketReader {
|
||||
return &PacketReader{
|
||||
reader: reader,
|
||||
@ -46,7 +40,8 @@ func NewPacketReader(reader io.Reader) *PacketReader {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *PacketReader) Read() (buf.MultiBuffer, error) {
|
||||
// ReadMultiBuffer implements buf.Reader.
|
||||
func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
if r.eof {
|
||||
return nil, io.EOF
|
||||
}
|
||||
@ -70,19 +65,22 @@ func (r *PacketReader) Read() (buf.MultiBuffer, error) {
|
||||
return buf.NewMultiBufferValue(b), nil
|
||||
}
|
||||
|
||||
// StreamReader reads Mux frame as a stream.
|
||||
type StreamReader struct {
|
||||
reader io.Reader
|
||||
reader *buf.BufferedReader
|
||||
leftOver int
|
||||
}
|
||||
|
||||
func NewStreamReader(reader io.Reader) *StreamReader {
|
||||
// NewStreamReader creates a new StreamReader.
|
||||
func NewStreamReader(reader *buf.BufferedReader) *StreamReader {
|
||||
return &StreamReader{
|
||||
reader: reader,
|
||||
leftOver: -1,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *StreamReader) Read() (buf.MultiBuffer, error) {
|
||||
// ReadMultiBuffer implmenets buf.Reader.
|
||||
func (r *StreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
if r.leftOver == 0 {
|
||||
r.leftOver = -1
|
||||
return nil, io.EOF
|
||||
@ -96,24 +94,7 @@ func (r *StreamReader) Read() (buf.MultiBuffer, error) {
|
||||
r.leftOver = int(size)
|
||||
}
|
||||
|
||||
mb := buf.NewMultiBuffer()
|
||||
for r.leftOver > 0 {
|
||||
readLen := buf.Size
|
||||
if r.leftOver < readLen {
|
||||
readLen = r.leftOver
|
||||
}
|
||||
b := buf.New()
|
||||
if err := b.AppendSupplier(func(bb []byte) (int, error) {
|
||||
return r.reader.Read(bb[:readLen])
|
||||
}); err != nil {
|
||||
mb.Release()
|
||||
return nil, err
|
||||
}
|
||||
r.leftOver -= b.Len()
|
||||
mb.Append(b)
|
||||
if b.Len() < readLen {
|
||||
break
|
||||
}
|
||||
}
|
||||
return mb, nil
|
||||
mb, err := r.reader.ReadAtMost(r.leftOver)
|
||||
r.leftOver -= mb.Len()
|
||||
return mb, err
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"v2ray.com/core/common/buf"
|
||||
@ -19,7 +18,7 @@ type SessionManager struct {
|
||||
func NewSessionManager() *SessionManager {
|
||||
return &SessionManager{
|
||||
count: 0,
|
||||
sessions: make(map[uint16]*Session, 32),
|
||||
sessions: make(map[uint16]*Session, 16),
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,6 +57,10 @@ func (m *SessionManager) Add(s *Session) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return
|
||||
}
|
||||
|
||||
m.sessions[s.ID] = s
|
||||
}
|
||||
|
||||
@ -65,6 +68,10 @@ func (m *SessionManager) Remove(id uint16) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return
|
||||
}
|
||||
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
|
||||
@ -111,9 +118,10 @@ func (m *SessionManager) Close() {
|
||||
s.output.Close()
|
||||
}
|
||||
|
||||
m.sessions = make(map[uint16]*Session)
|
||||
m.sessions = nil
|
||||
}
|
||||
|
||||
// Session represents a client connection in a Mux connection.
|
||||
type Session struct {
|
||||
input ray.InputStream
|
||||
output ray.OutputStream
|
||||
@ -122,13 +130,15 @@ type Session struct {
|
||||
transferType protocol.TransferType
|
||||
}
|
||||
|
||||
// Close closes all resources associated with this session.
|
||||
func (s *Session) Close() {
|
||||
s.output.Close()
|
||||
s.input.Close()
|
||||
s.parent.Remove(s.ID)
|
||||
}
|
||||
|
||||
func (s *Session) NewReader(reader io.Reader) buf.Reader {
|
||||
// NewReader creates a buf.Reader based on the transfer type of this Session.
|
||||
func (s *Session) NewReader(reader *buf.BufferedReader) buf.Reader {
|
||||
if s.transferType == protocol.TransferTypeStream {
|
||||
return NewStreamReader(reader)
|
||||
}
|
||||
|
@ -4,34 +4,36 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/app/proxyman/mux"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestSessionManagerAdd(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
m := NewSessionManager()
|
||||
|
||||
s := m.Allocate()
|
||||
assert.Uint16(s.ID).Equals(1)
|
||||
assert(s.ID, Equals, uint16(1))
|
||||
assert(m.Size(), Equals, 1)
|
||||
|
||||
s = m.Allocate()
|
||||
assert.Uint16(s.ID).Equals(2)
|
||||
assert(s.ID, Equals, uint16(2))
|
||||
assert(m.Size(), Equals, 2)
|
||||
|
||||
s = &Session{
|
||||
ID: 4,
|
||||
}
|
||||
m.Add(s)
|
||||
assert.Uint16(s.ID).Equals(4)
|
||||
assert(s.ID, Equals, uint16(4))
|
||||
}
|
||||
|
||||
func TestSessionManagerClose(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
m := NewSessionManager()
|
||||
s := m.Allocate()
|
||||
|
||||
assert.Bool(m.CloseIfNoSession()).IsFalse()
|
||||
assert(m.CloseIfNoSession(), IsFalse)
|
||||
m.Remove(s.ID)
|
||||
assert.Bool(m.CloseIfNoSession()).IsTrue()
|
||||
assert(m.CloseIfNoSession(), IsTrue)
|
||||
}
|
||||
|
@ -1,8 +1,7 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
@ -10,9 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type Writer struct {
|
||||
id uint16
|
||||
dest net.Destination
|
||||
writer buf.Writer
|
||||
id uint16
|
||||
followup bool
|
||||
transferType protocol.TransferType
|
||||
}
|
||||
@ -54,51 +53,47 @@ func (w *Writer) getNextFrameMeta() FrameMetadata {
|
||||
func (w *Writer) writeMetaOnly() error {
|
||||
meta := w.getNextFrameMeta()
|
||||
b := buf.New()
|
||||
if err := b.AppendSupplier(meta.AsSupplier()); err != nil {
|
||||
if err := b.Reset(meta.AsSupplier()); err != nil {
|
||||
return err
|
||||
}
|
||||
runtime.KeepAlive(meta)
|
||||
return w.writer.Write(buf.NewMultiBufferValue(b))
|
||||
return w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(b))
|
||||
}
|
||||
|
||||
func (w *Writer) writeData(mb buf.MultiBuffer) error {
|
||||
meta := w.getNextFrameMeta()
|
||||
meta.Option.Add(OptionData)
|
||||
meta.Option.Set(OptionData)
|
||||
|
||||
frame := buf.New()
|
||||
if err := frame.AppendSupplier(meta.AsSupplier()); err != nil {
|
||||
if err := frame.Reset(meta.AsSupplier()); err != nil {
|
||||
return err
|
||||
}
|
||||
runtime.KeepAlive(meta)
|
||||
if err := frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mb2 := buf.NewMultiBuffer()
|
||||
mb2 := buf.NewMultiBufferCap(len(mb) + 1)
|
||||
mb2.Append(frame)
|
||||
mb2.AppendMulti(mb)
|
||||
return w.writer.Write(mb2)
|
||||
return w.writer.WriteMultiBuffer(mb2)
|
||||
}
|
||||
|
||||
// Write implements buf.MultiBufferWriter.
|
||||
func (w *Writer) Write(mb buf.MultiBuffer) error {
|
||||
// WriteMultiBuffer implements buf.Writer.
|
||||
func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
if mb.IsEmpty() {
|
||||
return w.writeMetaOnly()
|
||||
}
|
||||
|
||||
if w.transferType == protocol.TransferTypeStream {
|
||||
const chunkSize = 8 * 1024
|
||||
for !mb.IsEmpty() {
|
||||
slice := mb.SliceBySize(chunkSize)
|
||||
if err := w.writeData(slice); err != nil {
|
||||
return err
|
||||
}
|
||||
for !mb.IsEmpty() {
|
||||
var chunk buf.MultiBuffer
|
||||
if w.transferType == protocol.TransferTypeStream {
|
||||
chunk = mb.SliceBySize(8 * 1024)
|
||||
} else {
|
||||
chunk = buf.NewMultiBufferValue(mb.SplitFirst())
|
||||
}
|
||||
} else {
|
||||
for _, b := range mb {
|
||||
if err := w.writeData(buf.NewMultiBufferValue(b)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := w.writeData(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,8 +107,7 @@ func (w *Writer) Close() {
|
||||
}
|
||||
|
||||
frame := buf.New()
|
||||
frame.AppendSupplier(meta.AsSupplier())
|
||||
runtime.KeepAlive(meta)
|
||||
common.Must(frame.Reset(meta.AsSupplier()))
|
||||
|
||||
w.writer.Write(buf.NewMultiBufferValue(frame))
|
||||
w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame))
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ func NewHandler(ctx context.Context, config *proxyman.OutboundHandlerConfig) (*H
|
||||
if space == nil {
|
||||
return nil, newError("no space in context")
|
||||
}
|
||||
space.OnInitialize(func() error {
|
||||
space.On(app.SpaceInitializing, func(interface{}) error {
|
||||
ohm := proxyman.OutboundHandlerManagerFromSpace(space)
|
||||
if ohm == nil {
|
||||
return newError("no OutboundManager in space")
|
||||
@ -78,6 +78,7 @@ func (h *Handler) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) {
|
||||
err := h.mux.Dispatch(ctx, outboundRay)
|
||||
if err != nil {
|
||||
log.Trace(newError("failed to process outbound traffic").Base(err))
|
||||
outboundRay.OutboundOutput().CloseError()
|
||||
}
|
||||
} else {
|
||||
err := h.proxy.Process(ctx, outboundRay, h)
|
||||
@ -122,8 +123,8 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
|
||||
}
|
||||
|
||||
var (
|
||||
_ buf.MultiBufferReader = (*Connection)(nil)
|
||||
_ buf.MultiBufferWriter = (*Connection)(nil)
|
||||
_ buf.Reader = (*Connection)(nil)
|
||||
_ buf.Writer = (*Connection)(nil)
|
||||
)
|
||||
|
||||
type Connection struct {
|
||||
@ -132,9 +133,8 @@ type Connection struct {
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
|
||||
bytesReader io.Reader
|
||||
reader buf.Reader
|
||||
writer buf.Writer
|
||||
reader *buf.BufferedReader
|
||||
writer buf.Writer
|
||||
}
|
||||
|
||||
func NewConnection(stream ray.Ray) *Connection {
|
||||
@ -148,9 +148,8 @@ func NewConnection(stream ray.Ray) *Connection {
|
||||
IP: []byte{0, 0, 0, 0},
|
||||
Port: 0,
|
||||
},
|
||||
bytesReader: buf.ToBytesReader(stream.InboundOutput()),
|
||||
reader: stream.InboundOutput(),
|
||||
writer: stream.InboundInput(),
|
||||
reader: buf.NewBufferedReader(stream.InboundOutput()),
|
||||
writer: stream.InboundInput(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,11 +158,11 @@ func (v *Connection) Read(b []byte) (int, error) {
|
||||
if v.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return v.bytesReader.Read(b)
|
||||
return v.reader.Read(b)
|
||||
}
|
||||
|
||||
func (v *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
return v.reader.Read()
|
||||
return v.reader.ReadMultiBuffer()
|
||||
}
|
||||
|
||||
// Write implements net.Conn.Write().
|
||||
@ -171,14 +170,19 @@ func (v *Connection) Write(b []byte) (int, error) {
|
||||
if v.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return buf.ToBytesWriter(v.writer).Write(b)
|
||||
|
||||
l := len(b)
|
||||
mb := buf.NewMultiBufferCap(l/buf.Size + 1)
|
||||
mb.Write(b)
|
||||
return l, v.writer.WriteMultiBuffer(mb)
|
||||
}
|
||||
|
||||
func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
if v.closed {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
return v.writer.Write(mb)
|
||||
|
||||
return v.writer.WriteMultiBuffer(mb)
|
||||
}
|
||||
|
||||
// Close implements net.Conn.Close().
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
@ -64,13 +66,123 @@ func (v *AnyCondition) Len() int {
|
||||
return len(*v)
|
||||
}
|
||||
|
||||
type PlainDomainMatcher string
|
||||
|
||||
func NewPlainDomainMatcher(pattern string) Condition {
|
||||
return PlainDomainMatcher(pattern)
|
||||
type timedResult struct {
|
||||
timestamp time.Time
|
||||
result bool
|
||||
}
|
||||
|
||||
func (v PlainDomainMatcher) Apply(ctx context.Context) bool {
|
||||
type CachableDomainMatcher struct {
|
||||
sync.Mutex
|
||||
matchers []domainMatcher
|
||||
cache map[string]timedResult
|
||||
lastScan time.Time
|
||||
}
|
||||
|
||||
func NewCachableDomainMatcher() *CachableDomainMatcher {
|
||||
return &CachableDomainMatcher{
|
||||
matchers: make([]domainMatcher, 0, 64),
|
||||
cache: make(map[string]timedResult, 512),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *CachableDomainMatcher) Add(domain *Domain) error {
|
||||
switch domain.Type {
|
||||
case Domain_Plain:
|
||||
m.matchers = append(m.matchers, NewPlainDomainMatcher(domain.Value))
|
||||
case Domain_Regex:
|
||||
rm, err := NewRegexpDomainMatcher(domain.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.matchers = append(m.matchers, rm)
|
||||
case Domain_Domain:
|
||||
m.matchers = append(m.matchers, NewSubDomainMatcher(domain.Value))
|
||||
default:
|
||||
return newError("unknown domain type: ", domain.Type).AtError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CachableDomainMatcher) applyInternal(domain string) bool {
|
||||
for _, matcher := range m.matchers {
|
||||
if matcher.Apply(domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
type cacheResult int
|
||||
|
||||
const (
|
||||
cacheMiss cacheResult = iota
|
||||
cacheHitTrue
|
||||
cacheHitFalse
|
||||
)
|
||||
|
||||
func (m *CachableDomainMatcher) findInCache(domain string) cacheResult {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
r, f := m.cache[domain]
|
||||
if !f {
|
||||
return cacheMiss
|
||||
}
|
||||
r.timestamp = time.Now()
|
||||
m.cache[domain] = r
|
||||
|
||||
if r.result {
|
||||
return cacheHitTrue
|
||||
}
|
||||
return cacheHitFalse
|
||||
}
|
||||
|
||||
func (m *CachableDomainMatcher) ApplyDomain(domain string) bool {
|
||||
if len(m.matchers) < 64 {
|
||||
return m.applyInternal(domain)
|
||||
}
|
||||
|
||||
cr := m.findInCache(domain)
|
||||
|
||||
if cr == cacheHitTrue {
|
||||
return true
|
||||
}
|
||||
|
||||
if cr == cacheHitFalse {
|
||||
return false
|
||||
}
|
||||
|
||||
r := m.applyInternal(domain)
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
m.cache[domain] = timedResult{
|
||||
result: r,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if len(m.cache) > 256 && now.Sub(m.lastScan)/time.Second > 5 {
|
||||
remove := make([]string, 0, 128)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for k, v := range m.cache {
|
||||
if now.Sub(v.timestamp)/time.Second > 60 {
|
||||
remove = append(remove, k)
|
||||
}
|
||||
}
|
||||
for _, v := range remove {
|
||||
delete(m.cache, v)
|
||||
}
|
||||
m.lastScan = now
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (m *CachableDomainMatcher) Apply(ctx context.Context) bool {
|
||||
dest, ok := proxy.TargetFromContext(ctx)
|
||||
if !ok {
|
||||
return false
|
||||
@ -79,7 +191,20 @@ func (v PlainDomainMatcher) Apply(ctx context.Context) bool {
|
||||
if !dest.Address.Family().IsDomain() {
|
||||
return false
|
||||
}
|
||||
domain := dest.Address.Domain()
|
||||
return m.ApplyDomain(dest.Address.Domain())
|
||||
}
|
||||
|
||||
type domainMatcher interface {
|
||||
Apply(domain string) bool
|
||||
}
|
||||
|
||||
type PlainDomainMatcher string
|
||||
|
||||
func NewPlainDomainMatcher(pattern string) PlainDomainMatcher {
|
||||
return PlainDomainMatcher(pattern)
|
||||
}
|
||||
|
||||
func (v PlainDomainMatcher) Apply(domain string) bool {
|
||||
return strings.Contains(domain, string(v))
|
||||
}
|
||||
|
||||
@ -97,33 +222,17 @@ func NewRegexpDomainMatcher(pattern string) (*RegexpDomainMatcher, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *RegexpDomainMatcher) Apply(ctx context.Context) bool {
|
||||
dest, ok := proxy.TargetFromContext(ctx)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !dest.Address.Family().IsDomain() {
|
||||
return false
|
||||
}
|
||||
domain := dest.Address.Domain()
|
||||
func (v *RegexpDomainMatcher) Apply(domain string) bool {
|
||||
return v.pattern.MatchString(strings.ToLower(domain))
|
||||
}
|
||||
|
||||
type SubDomainMatcher string
|
||||
|
||||
func NewSubDomainMatcher(p string) Condition {
|
||||
func NewSubDomainMatcher(p string) SubDomainMatcher {
|
||||
return SubDomainMatcher(p)
|
||||
}
|
||||
|
||||
func (m SubDomainMatcher) Apply(ctx context.Context) bool {
|
||||
dest, ok := proxy.TargetFromContext(ctx)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !dest.Address.Family().IsDomain() {
|
||||
return false
|
||||
}
|
||||
domain := dest.Address.Domain()
|
||||
func (m SubDomainMatcher) Apply(domain string) bool {
|
||||
pattern := string(m)
|
||||
if !strings.HasSuffix(domain, pattern) {
|
||||
return false
|
||||
@ -149,8 +258,9 @@ func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error)
|
||||
|
||||
func (v *CIDRMatcher) Apply(ctx context.Context) bool {
|
||||
ips := make([]net.IP, 0, 4)
|
||||
if resolveIPs, ok := proxy.ResolvedIPsFromContext(ctx); ok {
|
||||
for _, rip := range resolveIPs {
|
||||
if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok {
|
||||
resolvedIPs := resolver.Resolve()
|
||||
for _, rip := range resolvedIPs {
|
||||
if !rip.Family().IsIPv6() {
|
||||
continue
|
||||
}
|
||||
@ -192,8 +302,9 @@ func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
|
||||
|
||||
func (v *IPv4Matcher) Apply(ctx context.Context) bool {
|
||||
ips := make([]net.IP, 0, 4)
|
||||
if resolveIPs, ok := proxy.ResolvedIPsFromContext(ctx); ok {
|
||||
for _, rip := range resolveIPs {
|
||||
if resolver, ok := proxy.ResolvedIPsFromContext(ctx); ok {
|
||||
resolvedIPs := resolver.Resolve()
|
||||
for _, rip := range resolvedIPs {
|
||||
if !rip.Family().IsIPv4() {
|
||||
continue
|
||||
}
|
||||
|
@ -2,57 +2,66 @@ package router_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
. "v2ray.com/core/app/router"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/errors"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/platform"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/proxy"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
"v2ray.com/ext/sysio"
|
||||
)
|
||||
|
||||
func TestSubDomainMatcher(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cases := []struct {
|
||||
pattern string
|
||||
input context.Context
|
||||
input string
|
||||
output bool
|
||||
}{
|
||||
{
|
||||
pattern: "v2ray.com",
|
||||
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v2ray.com"), 80)),
|
||||
input: "www.v2ray.com",
|
||||
output: true,
|
||||
},
|
||||
{
|
||||
pattern: "v2ray.com",
|
||||
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("v2ray.com"), 80)),
|
||||
input: "v2ray.com",
|
||||
output: true,
|
||||
},
|
||||
{
|
||||
pattern: "v2ray.com",
|
||||
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("www.v3ray.com"), 80)),
|
||||
input: "www.v3ray.com",
|
||||
output: false,
|
||||
},
|
||||
{
|
||||
pattern: "v2ray.com",
|
||||
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("2ray.com"), 80)),
|
||||
input: "2ray.com",
|
||||
output: false,
|
||||
},
|
||||
{
|
||||
pattern: "v2ray.com",
|
||||
input: proxy.ContextWithTarget(context.Background(), net.TCPDestination(net.DomainAddress("xv2ray.com"), 80)),
|
||||
input: "xv2ray.com",
|
||||
output: false,
|
||||
},
|
||||
}
|
||||
for _, test := range cases {
|
||||
matcher := NewSubDomainMatcher(test.pattern)
|
||||
assert.Bool(matcher.Apply(test.input) == test.output).IsTrue()
|
||||
assert(matcher.Apply(test.input) == test.output, IsTrue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutingRule(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
type ruleTest struct {
|
||||
input context.Context
|
||||
@ -172,10 +181,56 @@ func TestRoutingRule(t *testing.T) {
|
||||
|
||||
for _, test := range cases {
|
||||
cond, err := test.rule.BuildCondition()
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
|
||||
for _, t := range test.test {
|
||||
assert.Bool(cond.Apply(t.input)).Equals(t.output)
|
||||
assert(cond.Apply(t.input), Equals, t.output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func loadGeoSite(country string) ([]*Domain, error) {
|
||||
geositeBytes, err := sysio.ReadAsset("geosite.dat")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var geositeList GeoSiteList
|
||||
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, site := range geositeList.Entry {
|
||||
if site.CountryCode == country {
|
||||
return site.Domain, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("country not found: " + country)
|
||||
}
|
||||
|
||||
func TestChinaSites(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geosite.dat")))
|
||||
|
||||
domains, err := loadGeoSite("CN")
|
||||
assert(err, IsNil)
|
||||
|
||||
matcher := NewCachableDomainMatcher()
|
||||
for _, d := range domains {
|
||||
assert(matcher.Add(d), IsNil)
|
||||
}
|
||||
|
||||
assert(matcher.ApplyDomain("163.com"), IsTrue)
|
||||
assert(matcher.ApplyDomain("163.com"), IsTrue)
|
||||
assert(matcher.ApplyDomain("164.com"), IsFalse)
|
||||
assert(matcher.ApplyDomain("164.com"), IsFalse)
|
||||
|
||||
for i := 0; i < 1024; i++ {
|
||||
assert(matcher.ApplyDomain(strconv.Itoa(i)+".not-exists.com"), IsFalse)
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
for i := 0; i < 1024; i++ {
|
||||
assert(matcher.ApplyDomain(strconv.Itoa(i)+".not-exists2.com"), IsFalse)
|
||||
}
|
||||
}
|
||||
|
@ -52,24 +52,11 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
|
||||
conds := NewConditionChan()
|
||||
|
||||
if len(rr.Domain) > 0 {
|
||||
anyCond := NewAnyCondition()
|
||||
matcher := NewCachableDomainMatcher()
|
||||
for _, domain := range rr.Domain {
|
||||
switch domain.Type {
|
||||
case Domain_Plain:
|
||||
anyCond.Add(NewPlainDomainMatcher(domain.Value))
|
||||
case Domain_Regex:
|
||||
matcher, err := NewRegexpDomainMatcher(domain.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
anyCond.Add(matcher)
|
||||
case Domain_Domain:
|
||||
anyCond.Add(NewSubDomainMatcher(domain.Value))
|
||||
default:
|
||||
panic("Unknown domain type.")
|
||||
}
|
||||
matcher.Add(domain)
|
||||
}
|
||||
conds.Add(anyCond)
|
||||
conds.Add(matcher)
|
||||
}
|
||||
|
||||
if len(rr.Cidr) > 0 {
|
||||
|
@ -54,23 +54,27 @@ const (
|
||||
Config_UseIp Config_DomainStrategy = 1
|
||||
// Resolve to IP if the domain doesn't match any rules.
|
||||
Config_IpIfNonMatch Config_DomainStrategy = 2
|
||||
// Resolve to IP if any rule requires IP matching.
|
||||
Config_IpOnDemand Config_DomainStrategy = 3
|
||||
)
|
||||
|
||||
var Config_DomainStrategy_name = map[int32]string{
|
||||
0: "AsIs",
|
||||
1: "UseIp",
|
||||
2: "IpIfNonMatch",
|
||||
3: "IpOnDemand",
|
||||
}
|
||||
var Config_DomainStrategy_value = map[string]int32{
|
||||
"AsIs": 0,
|
||||
"UseIp": 1,
|
||||
"IpIfNonMatch": 2,
|
||||
"IpOnDemand": 3,
|
||||
}
|
||||
|
||||
func (x Config_DomainStrategy) String() string {
|
||||
return proto.EnumName(Config_DomainStrategy_name, int32(x))
|
||||
}
|
||||
func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{3, 0} }
|
||||
func (Config_DomainStrategy) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{7, 0} }
|
||||
|
||||
// Domain for routing decision.
|
||||
type Domain struct {
|
||||
@ -126,6 +130,86 @@ func (m *CIDR) GetPrefix() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
type GeoIP struct {
|
||||
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode" json:"country_code,omitempty"`
|
||||
Cidr []*CIDR `protobuf:"bytes,2,rep,name=cidr" json:"cidr,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GeoIP) Reset() { *m = GeoIP{} }
|
||||
func (m *GeoIP) String() string { return proto.CompactTextString(m) }
|
||||
func (*GeoIP) ProtoMessage() {}
|
||||
func (*GeoIP) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||
|
||||
func (m *GeoIP) GetCountryCode() string {
|
||||
if m != nil {
|
||||
return m.CountryCode
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *GeoIP) GetCidr() []*CIDR {
|
||||
if m != nil {
|
||||
return m.Cidr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeoIPList struct {
|
||||
Entry []*GeoIP `protobuf:"bytes,1,rep,name=entry" json:"entry,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GeoIPList) Reset() { *m = GeoIPList{} }
|
||||
func (m *GeoIPList) String() string { return proto.CompactTextString(m) }
|
||||
func (*GeoIPList) ProtoMessage() {}
|
||||
func (*GeoIPList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||
|
||||
func (m *GeoIPList) GetEntry() []*GeoIP {
|
||||
if m != nil {
|
||||
return m.Entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeoSite struct {
|
||||
CountryCode string `protobuf:"bytes,1,opt,name=country_code,json=countryCode" json:"country_code,omitempty"`
|
||||
Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GeoSite) Reset() { *m = GeoSite{} }
|
||||
func (m *GeoSite) String() string { return proto.CompactTextString(m) }
|
||||
func (*GeoSite) ProtoMessage() {}
|
||||
func (*GeoSite) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{4} }
|
||||
|
||||
func (m *GeoSite) GetCountryCode() string {
|
||||
if m != nil {
|
||||
return m.CountryCode
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *GeoSite) GetDomain() []*Domain {
|
||||
if m != nil {
|
||||
return m.Domain
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeoSiteList struct {
|
||||
Entry []*GeoSite `protobuf:"bytes,1,rep,name=entry" json:"entry,omitempty"`
|
||||
}
|
||||
|
||||
func (m *GeoSiteList) Reset() { *m = GeoSiteList{} }
|
||||
func (m *GeoSiteList) String() string { return proto.CompactTextString(m) }
|
||||
func (*GeoSiteList) ProtoMessage() {}
|
||||
func (*GeoSiteList) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{5} }
|
||||
|
||||
func (m *GeoSiteList) GetEntry() []*GeoSite {
|
||||
if m != nil {
|
||||
return m.Entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RoutingRule struct {
|
||||
Tag string `protobuf:"bytes,1,opt,name=tag" json:"tag,omitempty"`
|
||||
Domain []*Domain `protobuf:"bytes,2,rep,name=domain" json:"domain,omitempty"`
|
||||
@ -140,7 +224,7 @@ type RoutingRule struct {
|
||||
func (m *RoutingRule) Reset() { *m = RoutingRule{} }
|
||||
func (m *RoutingRule) String() string { return proto.CompactTextString(m) }
|
||||
func (*RoutingRule) ProtoMessage() {}
|
||||
func (*RoutingRule) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{2} }
|
||||
func (*RoutingRule) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{6} }
|
||||
|
||||
func (m *RoutingRule) GetTag() string {
|
||||
if m != nil {
|
||||
@ -206,7 +290,7 @@ type Config struct {
|
||||
func (m *Config) Reset() { *m = Config{} }
|
||||
func (m *Config) String() string { return proto.CompactTextString(m) }
|
||||
func (*Config) ProtoMessage() {}
|
||||
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{3} }
|
||||
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{7} }
|
||||
|
||||
func (m *Config) GetDomainStrategy() Config_DomainStrategy {
|
||||
if m != nil {
|
||||
@ -225,6 +309,10 @@ func (m *Config) GetRule() []*RoutingRule {
|
||||
func init() {
|
||||
proto.RegisterType((*Domain)(nil), "v2ray.core.app.router.Domain")
|
||||
proto.RegisterType((*CIDR)(nil), "v2ray.core.app.router.CIDR")
|
||||
proto.RegisterType((*GeoIP)(nil), "v2ray.core.app.router.GeoIP")
|
||||
proto.RegisterType((*GeoIPList)(nil), "v2ray.core.app.router.GeoIPList")
|
||||
proto.RegisterType((*GeoSite)(nil), "v2ray.core.app.router.GeoSite")
|
||||
proto.RegisterType((*GeoSiteList)(nil), "v2ray.core.app.router.GeoSiteList")
|
||||
proto.RegisterType((*RoutingRule)(nil), "v2ray.core.app.router.RoutingRule")
|
||||
proto.RegisterType((*Config)(nil), "v2ray.core.app.router.Config")
|
||||
proto.RegisterEnum("v2ray.core.app.router.Domain_Type", Domain_Type_name, Domain_Type_value)
|
||||
@ -234,39 +322,45 @@ func init() {
|
||||
func init() { proto.RegisterFile("v2ray.com/core/app/router/config.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 538 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x93, 0xc1, 0x6e, 0xd4, 0x30,
|
||||
0x10, 0x86, 0x49, 0x76, 0x1b, 0xba, 0x93, 0xb2, 0x44, 0x16, 0x45, 0xa1, 0xa8, 0x22, 0x8a, 0x10,
|
||||
0xe4, 0x80, 0x12, 0x69, 0x11, 0x70, 0x01, 0xa1, 0xb2, 0xed, 0x61, 0x25, 0xa8, 0x2a, 0xd3, 0x72,
|
||||
0xe0, 0x12, 0xb9, 0x59, 0x37, 0x58, 0x24, 0xb6, 0xe5, 0x38, 0xa5, 0x7b, 0xe3, 0x05, 0x78, 0x11,
|
||||
0x9e, 0x86, 0x47, 0x42, 0xb6, 0x53, 0xd1, 0xa2, 0x2e, 0xdc, 0x66, 0x9c, 0xef, 0x9f, 0x19, 0x8f,
|
||||
0xff, 0xc0, 0x93, 0xf3, 0x99, 0x22, 0xab, 0xbc, 0x12, 0x6d, 0x51, 0x09, 0x45, 0x0b, 0x22, 0x65,
|
||||
0xa1, 0x44, 0xaf, 0xa9, 0x2a, 0x2a, 0xc1, 0xcf, 0x58, 0x9d, 0x4b, 0x25, 0xb4, 0x40, 0xdb, 0x97,
|
||||
0x9c, 0xa2, 0x39, 0x91, 0x32, 0x77, 0xcc, 0xce, 0xe3, 0xbf, 0xe4, 0x95, 0x68, 0x5b, 0xc1, 0x0b,
|
||||
0x4e, 0x75, 0x21, 0x85, 0xd2, 0x4e, 0xbc, 0xf3, 0x74, 0x3d, 0xc5, 0xa9, 0xfe, 0x26, 0xd4, 0x57,
|
||||
0x07, 0xa6, 0xdf, 0x3d, 0x08, 0xf6, 0x45, 0x4b, 0x18, 0x47, 0x2f, 0x61, 0xac, 0x57, 0x92, 0xc6,
|
||||
0x5e, 0xe2, 0x65, 0xd3, 0x59, 0x9a, 0xdf, 0xd8, 0x3f, 0x77, 0x70, 0x7e, 0xbc, 0x92, 0x14, 0x5b,
|
||||
0x1e, 0xdd, 0x83, 0x8d, 0x73, 0xd2, 0xf4, 0x34, 0xf6, 0x13, 0x2f, 0x9b, 0x60, 0x97, 0xa4, 0x19,
|
||||
0x8c, 0x0d, 0x83, 0x26, 0xb0, 0x71, 0xd4, 0x10, 0xc6, 0xa3, 0x5b, 0x26, 0xc4, 0xb4, 0xa6, 0x17,
|
||||
0x91, 0x87, 0xe0, 0xb2, 0x6b, 0xe4, 0xa7, 0x39, 0x8c, 0xe7, 0x8b, 0x7d, 0x8c, 0xa6, 0xe0, 0x33,
|
||||
0x69, 0xbb, 0x6f, 0x61, 0x9f, 0x49, 0x74, 0x1f, 0x02, 0xa9, 0xe8, 0x19, 0xbb, 0xb0, 0x85, 0xef,
|
||||
0xe0, 0x21, 0x4b, 0x7f, 0x8c, 0x20, 0xc4, 0xa2, 0xd7, 0x8c, 0xd7, 0xb8, 0x6f, 0x28, 0x8a, 0x60,
|
||||
0xa4, 0x49, 0x6d, 0x85, 0x13, 0x6c, 0x42, 0xf4, 0x02, 0x82, 0xa5, 0xad, 0x1e, 0xfb, 0xc9, 0x28,
|
||||
0x0b, 0x67, 0xbb, 0xff, 0xbc, 0x0b, 0x1e, 0x60, 0x54, 0xc0, 0xb8, 0x62, 0x4b, 0x15, 0x8f, 0xac,
|
||||
0xe8, 0xe1, 0x1a, 0x91, 0x99, 0x15, 0x5b, 0x10, 0xbd, 0x05, 0x30, 0x3b, 0x2f, 0x15, 0xe1, 0x35,
|
||||
0x8d, 0xc7, 0x89, 0x97, 0x85, 0xb3, 0xe4, 0xaa, 0xcc, 0xad, 0x3d, 0xe7, 0x54, 0xe7, 0x47, 0x42,
|
||||
0x69, 0x6c, 0x38, 0x3c, 0x91, 0x97, 0x21, 0x3a, 0x80, 0xad, 0xe1, 0x39, 0xca, 0x86, 0x75, 0x3a,
|
||||
0xde, 0xb0, 0x25, 0xd2, 0x35, 0x25, 0x0e, 0x1d, 0xfa, 0x9e, 0x75, 0x1a, 0x87, 0xfc, 0x4f, 0x82,
|
||||
0x5e, 0x43, 0xd8, 0x89, 0x5e, 0x55, 0xb4, 0xb4, 0xf3, 0x07, 0xff, 0x9f, 0x1f, 0x1c, 0x3f, 0x37,
|
||||
0xb7, 0xd8, 0x05, 0xe8, 0x3b, 0xaa, 0x4a, 0xda, 0x12, 0xd6, 0xc4, 0xb7, 0x93, 0x51, 0x36, 0xc1,
|
||||
0x13, 0x73, 0x72, 0x60, 0x0e, 0xd0, 0x23, 0x08, 0x19, 0x3f, 0x15, 0x3d, 0x5f, 0x96, 0x66, 0xcd,
|
||||
0x9b, 0xf6, 0x3b, 0x0c, 0x47, 0xc7, 0xa4, 0x4e, 0x7f, 0x79, 0x10, 0xcc, 0xad, 0x73, 0xd1, 0x09,
|
||||
0xdc, 0x75, 0xbb, 0x2c, 0x3b, 0xad, 0x88, 0xa6, 0xf5, 0x6a, 0x70, 0xd3, 0xb3, 0x75, 0xc3, 0x38,
|
||||
0xc7, 0xbb, 0x87, 0xf8, 0x38, 0x68, 0xf0, 0x74, 0x79, 0x2d, 0x37, 0xce, 0x54, 0x7d, 0x43, 0x87,
|
||||
0xd7, 0x5c, 0xe7, 0xcc, 0x2b, 0x9e, 0xc0, 0x96, 0x4f, 0x5f, 0xc1, 0xf4, 0x7a, 0x65, 0xb4, 0x09,
|
||||
0xe3, 0xbd, 0x6e, 0xd1, 0x39, 0x33, 0x9e, 0x74, 0x74, 0x21, 0x23, 0x0f, 0x45, 0xb0, 0xb5, 0x90,
|
||||
0x8b, 0xb3, 0x43, 0xc1, 0x3f, 0x10, 0x5d, 0x7d, 0x89, 0xfc, 0x77, 0x6f, 0xe0, 0x41, 0x25, 0xda,
|
||||
0x9b, 0xfb, 0x1c, 0x79, 0x9f, 0x03, 0x17, 0xfd, 0xf4, 0xb7, 0x3f, 0xcd, 0x30, 0x59, 0xe5, 0x73,
|
||||
0x43, 0xec, 0x49, 0x69, 0x47, 0xa0, 0xea, 0x34, 0xb0, 0xff, 0xd6, 0xf3, 0xdf, 0x01, 0x00, 0x00,
|
||||
0xff, 0xff, 0xa7, 0x6a, 0x97, 0x93, 0xeb, 0x03, 0x00, 0x00,
|
||||
// 640 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x94, 0xcd, 0x6e, 0xd4, 0x3a,
|
||||
0x14, 0xc7, 0x6f, 0xe6, 0xab, 0x9d, 0x93, 0xb9, 0x73, 0x23, 0xeb, 0x16, 0x0d, 0x85, 0xc2, 0x10,
|
||||
0x21, 0x98, 0x05, 0x4a, 0xa4, 0xe1, 0x63, 0x05, 0xaa, 0xca, 0xb4, 0xaa, 0x22, 0x41, 0x19, 0xb9,
|
||||
0x2d, 0x0b, 0x58, 0x44, 0x6e, 0xe2, 0x86, 0x88, 0x89, 0x6d, 0x39, 0x4e, 0xe9, 0xec, 0x78, 0x01,
|
||||
0x5e, 0x84, 0xa7, 0xe2, 0x51, 0x90, 0xed, 0x0c, 0xb4, 0xa8, 0x81, 0x8a, 0x9d, 0xed, 0xfc, 0xfe,
|
||||
0xe7, 0xfc, 0x73, 0x7c, 0x8e, 0xe1, 0xc1, 0xd9, 0x54, 0x92, 0x65, 0x90, 0xf0, 0x22, 0x4c, 0xb8,
|
||||
0xa4, 0x21, 0x11, 0x22, 0x94, 0xbc, 0x52, 0x54, 0x86, 0x09, 0x67, 0xa7, 0x79, 0x16, 0x08, 0xc9,
|
||||
0x15, 0x47, 0x1b, 0x2b, 0x4e, 0xd2, 0x80, 0x08, 0x11, 0x58, 0x66, 0xf3, 0xfe, 0x2f, 0xf2, 0x84,
|
||||
0x17, 0x05, 0x67, 0x21, 0xa3, 0x2a, 0x14, 0x5c, 0x2a, 0x2b, 0xde, 0x7c, 0xd8, 0x4c, 0x31, 0xaa,
|
||||
0x3e, 0x71, 0xf9, 0xd1, 0x82, 0xfe, 0x67, 0x07, 0x7a, 0xbb, 0xbc, 0x20, 0x39, 0x43, 0xcf, 0xa0,
|
||||
0xa3, 0x96, 0x82, 0x8e, 0x9c, 0xb1, 0x33, 0x19, 0x4e, 0xfd, 0xe0, 0xca, 0xfc, 0x81, 0x85, 0x83,
|
||||
0xa3, 0xa5, 0xa0, 0xd8, 0xf0, 0xe8, 0x7f, 0xe8, 0x9e, 0x91, 0x45, 0x45, 0x47, 0xad, 0xb1, 0x33,
|
||||
0xe9, 0x63, 0xbb, 0xf1, 0x27, 0xd0, 0xd1, 0x0c, 0xea, 0x43, 0x77, 0xbe, 0x20, 0x39, 0xf3, 0xfe,
|
||||
0xd1, 0x4b, 0x4c, 0x33, 0x7a, 0xee, 0x39, 0x08, 0x56, 0x59, 0xbd, 0x96, 0x1f, 0x40, 0x67, 0x16,
|
||||
0xed, 0x62, 0x34, 0x84, 0x56, 0x2e, 0x4c, 0xf6, 0x01, 0x6e, 0xe5, 0x02, 0xdd, 0x80, 0x9e, 0x90,
|
||||
0xf4, 0x34, 0x3f, 0x37, 0x81, 0xff, 0xc5, 0xf5, 0xce, 0x7f, 0x0f, 0xdd, 0x7d, 0xca, 0xa3, 0x39,
|
||||
0xba, 0x07, 0x83, 0x84, 0x57, 0x4c, 0xc9, 0x65, 0x9c, 0xf0, 0xd4, 0x1a, 0xef, 0x63, 0xb7, 0x3e,
|
||||
0x9b, 0xf1, 0x94, 0xa2, 0x10, 0x3a, 0x49, 0x9e, 0xca, 0x51, 0x6b, 0xdc, 0x9e, 0xb8, 0xd3, 0x5b,
|
||||
0x0d, 0xff, 0xa4, 0xd3, 0x63, 0x03, 0xfa, 0xdb, 0xd0, 0x37, 0xc1, 0x5f, 0xe5, 0xa5, 0x42, 0x53,
|
||||
0xe8, 0x52, 0x1d, 0x6a, 0xe4, 0x18, 0xf9, 0xed, 0x06, 0xb9, 0x11, 0x60, 0x8b, 0xfa, 0x09, 0xac,
|
||||
0xed, 0x53, 0x7e, 0x98, 0x2b, 0x7a, 0x1d, 0x7f, 0x4f, 0xa1, 0x97, 0x9a, 0x3a, 0xd4, 0x0e, 0xb7,
|
||||
0x7e, 0x5b, 0x75, 0x5c, 0xc3, 0xfe, 0x0c, 0xdc, 0x3a, 0x89, 0xf1, 0xf9, 0xe4, 0xb2, 0xcf, 0x3b,
|
||||
0xcd, 0x3e, 0xb5, 0x64, 0xe5, 0xf4, 0x4b, 0x1b, 0x5c, 0xcc, 0x2b, 0x95, 0xb3, 0x0c, 0x57, 0x0b,
|
||||
0x8a, 0x3c, 0x68, 0x2b, 0x92, 0xd5, 0x2e, 0xf5, 0xf2, 0x2f, 0xdd, 0xfd, 0x28, 0x7a, 0xfb, 0x9a,
|
||||
0x45, 0x47, 0xdb, 0x00, 0xba, 0x77, 0x63, 0x49, 0x58, 0x46, 0x47, 0x9d, 0xb1, 0x33, 0x71, 0xa7,
|
||||
0xe3, 0x8b, 0x32, 0xdb, 0xbe, 0x01, 0xa3, 0x2a, 0x98, 0x73, 0xa9, 0xb0, 0xe6, 0x70, 0x5f, 0xac,
|
||||
0x96, 0x68, 0x0f, 0x06, 0x75, 0x5b, 0xc7, 0x8b, 0xbc, 0x54, 0xa3, 0xae, 0x09, 0xe1, 0x37, 0x84,
|
||||
0x38, 0xb0, 0xa8, 0x2e, 0x1d, 0x76, 0xd9, 0xcf, 0x0d, 0x7a, 0x0e, 0x6e, 0xc9, 0x2b, 0x99, 0xd0,
|
||||
0xd8, 0xf8, 0xef, 0xfd, 0xd9, 0x3f, 0x58, 0x7e, 0xa6, 0xff, 0x62, 0x0b, 0xa0, 0x2a, 0xa9, 0x8c,
|
||||
0x69, 0x41, 0xf2, 0xc5, 0x68, 0x6d, 0xdc, 0x9e, 0xf4, 0x71, 0x5f, 0x9f, 0xec, 0xe9, 0x03, 0x74,
|
||||
0x17, 0xdc, 0x9c, 0x9d, 0xf0, 0x8a, 0xa5, 0xb1, 0x2e, 0xf3, 0xba, 0xf9, 0x0e, 0xf5, 0xd1, 0x11,
|
||||
0xc9, 0xfc, 0x6f, 0x0e, 0xf4, 0x66, 0xe6, 0x05, 0x40, 0xc7, 0xf0, 0x9f, 0xad, 0x65, 0x5c, 0x2a,
|
||||
0x49, 0x14, 0xcd, 0x96, 0xf5, 0x54, 0x3e, 0x6a, 0x32, 0x63, 0x5f, 0x0e, 0x7b, 0x11, 0x87, 0xb5,
|
||||
0x06, 0x0f, 0xd3, 0x4b, 0x7b, 0x3d, 0xe1, 0xb2, 0x5a, 0xd0, 0xfa, 0x36, 0x9b, 0x26, 0xfc, 0x42,
|
||||
0x4f, 0x60, 0xc3, 0xfb, 0xfb, 0x30, 0xbc, 0x1c, 0x19, 0xad, 0x43, 0x67, 0xa7, 0x8c, 0x4a, 0x3b,
|
||||
0xd4, 0xc7, 0x25, 0x8d, 0x84, 0xe7, 0x20, 0x0f, 0x06, 0x91, 0x88, 0x4e, 0x0f, 0x38, 0x7b, 0x4d,
|
||||
0x54, 0xf2, 0xc1, 0x6b, 0xa1, 0x21, 0x40, 0x24, 0xde, 0xb0, 0x5d, 0x5a, 0x10, 0x96, 0x7a, 0xed,
|
||||
0x97, 0x2f, 0xe0, 0x66, 0xc2, 0x8b, 0xab, 0xf3, 0xce, 0x9d, 0x77, 0x3d, 0xbb, 0xfa, 0xda, 0xda,
|
||||
0x78, 0x3b, 0xc5, 0x64, 0x19, 0xcc, 0x34, 0xb1, 0x23, 0x84, 0xb1, 0x44, 0xe5, 0x49, 0xcf, 0xbc,
|
||||
0x59, 0x8f, 0xbf, 0x07, 0x00, 0x00, 0xff, 0xff, 0x53, 0x7c, 0xa8, 0x94, 0x43, 0x05, 0x00, 0x00,
|
||||
}
|
||||
|
@ -37,6 +37,24 @@ message CIDR {
|
||||
uint32 prefix = 2;
|
||||
}
|
||||
|
||||
message GeoIP {
|
||||
string country_code = 1;
|
||||
repeated CIDR cidr = 2;
|
||||
}
|
||||
|
||||
message GeoIPList {
|
||||
repeated GeoIP entry = 1;
|
||||
}
|
||||
|
||||
message GeoSite {
|
||||
string country_code = 1;
|
||||
repeated Domain domain = 2;
|
||||
}
|
||||
|
||||
message GeoSiteList{
|
||||
repeated GeoSite entry = 1;
|
||||
}
|
||||
|
||||
message RoutingRule {
|
||||
string tag = 1;
|
||||
repeated Domain domain = 2;
|
||||
@ -58,6 +76,9 @@ message Config {
|
||||
|
||||
// Resolve to IP if the domain doesn't match any rules.
|
||||
IpIfNonMatch = 2;
|
||||
|
||||
// Resolve to IP if any rule requires IP matching.
|
||||
IpOnDemand = 3;
|
||||
}
|
||||
DomainStrategy domain_strategy = 1;
|
||||
repeated RoutingRule rule = 2;
|
||||
|
@ -33,7 +33,7 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
|
||||
rules: make([]Rule, len(config.Rule)),
|
||||
}
|
||||
|
||||
space.OnInitialize(func() error {
|
||||
space.On(app.SpaceInitializing, func(interface{}) error {
|
||||
for idx, rule := range config.Rule {
|
||||
r.rules[idx].Tag = rule.Tag
|
||||
cond, err := rule.BuildCondition()
|
||||
@ -52,19 +52,42 @@ func NewRouter(ctx context.Context, config *Config) (*Router, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Router) resolveIP(dest net.Destination) []net.Address {
|
||||
ips := r.dnsServer.Get(dest.Address.Domain())
|
||||
type ipResolver struct {
|
||||
ip []net.Address
|
||||
domain string
|
||||
resolved bool
|
||||
dnsServer dns.Server
|
||||
}
|
||||
|
||||
func (r *ipResolver) Resolve() []net.Address {
|
||||
if r.resolved {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
log.Trace(newError("looking for IP for domain: ", r.domain))
|
||||
r.resolved = true
|
||||
ips := r.dnsServer.Get(r.domain)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
dests := make([]net.Address, len(ips))
|
||||
for idx, ip := range ips {
|
||||
dests[idx] = net.IPAddress(ip)
|
||||
r.ip = make([]net.Address, len(ips))
|
||||
for i, ip := range ips {
|
||||
r.ip[i] = net.IPAddress(ip)
|
||||
}
|
||||
return dests
|
||||
return r.ip
|
||||
}
|
||||
|
||||
func (r *Router) TakeDetour(ctx context.Context) (string, error) {
|
||||
resolver := &ipResolver{
|
||||
dnsServer: r.dnsServer,
|
||||
}
|
||||
if r.domainStrategy == Config_IpOnDemand {
|
||||
if dest, ok := proxy.TargetFromContext(ctx); ok && dest.Address.Family().IsDomain() {
|
||||
resolver.domain = dest.Address.Domain()
|
||||
ctx = proxy.ContextWithResolveIPs(ctx, resolver)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range r.rules {
|
||||
if rule.Apply(ctx) {
|
||||
return rule.Tag, nil
|
||||
@ -77,10 +100,10 @@ func (r *Router) TakeDetour(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
|
||||
log.Trace(newError("looking up IP for ", dest))
|
||||
ipDests := r.resolveIP(dest)
|
||||
if ipDests != nil {
|
||||
ctx = proxy.ContextWithResolveIPs(ctx, ipDests)
|
||||
resolver.domain = dest.Address.Domain()
|
||||
ips := resolver.Resolve()
|
||||
if len(ips) > 0 {
|
||||
ctx = proxy.ContextWithResolveIPs(ctx, resolver)
|
||||
for _, rule := range r.rules {
|
||||
if rule.Apply(ctx) {
|
||||
return rule.Tag, nil
|
||||
|
@ -14,11 +14,11 @@ import (
|
||||
. "v2ray.com/core/app/router"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/proxy"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestSimpleRouter(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
config := &Config{
|
||||
Rule: []*RoutingRule{
|
||||
@ -33,16 +33,16 @@ func TestSimpleRouter(t *testing.T) {
|
||||
|
||||
space := app.NewSpace()
|
||||
ctx := app.ContextWithSpace(context.Background(), space)
|
||||
assert.Error(app.AddApplicationToSpace(ctx, new(dns.Config))).IsNil()
|
||||
assert.Error(app.AddApplicationToSpace(ctx, new(dispatcher.Config))).IsNil()
|
||||
assert.Error(app.AddApplicationToSpace(ctx, new(proxyman.OutboundConfig))).IsNil()
|
||||
assert.Error(app.AddApplicationToSpace(ctx, config)).IsNil()
|
||||
assert.Error(space.Initialize()).IsNil()
|
||||
assert(app.AddApplicationToSpace(ctx, new(dns.Config)), IsNil)
|
||||
assert(app.AddApplicationToSpace(ctx, new(dispatcher.Config)), IsNil)
|
||||
assert(app.AddApplicationToSpace(ctx, new(proxyman.OutboundConfig)), IsNil)
|
||||
assert(app.AddApplicationToSpace(ctx, config), IsNil)
|
||||
assert(space.Initialize(), IsNil)
|
||||
|
||||
r := FromSpace(space)
|
||||
|
||||
ctx = proxy.ContextWithTarget(ctx, net.TCPDestination(net.DomainAddress("v2ray.com"), 80))
|
||||
tag, err := r.TakeDetour(ctx)
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(tag).Equals("test")
|
||||
assert(err, IsNil)
|
||||
assert(tag, Equals, "test")
|
||||
}
|
||||
|
38
app/space.go
38
app/space.go
@ -5,6 +5,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/event"
|
||||
)
|
||||
|
||||
type Application interface {
|
||||
@ -13,8 +14,6 @@ type Application interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
type InitializationCallback func() error
|
||||
|
||||
func CreateAppFromConfig(ctx context.Context, config interface{}) (Application, error) {
|
||||
application, err := common.CreateObject(ctx, config)
|
||||
if err != nil {
|
||||
@ -29,46 +28,47 @@ func CreateAppFromConfig(ctx context.Context, config interface{}) (Application,
|
||||
}
|
||||
|
||||
// A Space contains all apps that may be available in a V2Ray runtime.
|
||||
// Caller must check the availability of an app by calling HasXXX before getting its instance.
|
||||
type Space interface {
|
||||
event.Registry
|
||||
GetApplication(appInterface interface{}) Application
|
||||
AddApplication(application Application) error
|
||||
Initialize() error
|
||||
OnInitialize(InitializationCallback)
|
||||
Start() error
|
||||
Close()
|
||||
}
|
||||
|
||||
const (
|
||||
// SpaceInitializing is an event to be fired when Space is being initialized.
|
||||
SpaceInitializing event.Event = iota
|
||||
)
|
||||
|
||||
type spaceImpl struct {
|
||||
initialized bool
|
||||
event.Listener
|
||||
cache map[reflect.Type]Application
|
||||
appInit []InitializationCallback
|
||||
initialized bool
|
||||
}
|
||||
|
||||
// NewSpace creates a new Space.
|
||||
func NewSpace() Space {
|
||||
return &spaceImpl{
|
||||
cache: make(map[reflect.Type]Application),
|
||||
appInit: make([]InitializationCallback, 0, 32),
|
||||
cache: make(map[reflect.Type]Application),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *spaceImpl) OnInitialize(f InitializationCallback) {
|
||||
if s.initialized {
|
||||
f()
|
||||
} else {
|
||||
s.appInit = append(s.appInit, f)
|
||||
func (s *spaceImpl) On(e event.Event, h event.Handler) {
|
||||
if e == SpaceInitializing && s.initialized {
|
||||
_ = h(nil) // Ignore error
|
||||
return
|
||||
}
|
||||
s.Listener.On(e, h)
|
||||
}
|
||||
|
||||
func (s *spaceImpl) Initialize() error {
|
||||
for _, f := range s.appInit {
|
||||
if err := f(); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.initialized {
|
||||
return nil
|
||||
}
|
||||
s.appInit = nil
|
||||
s.initialized = true
|
||||
return nil
|
||||
return s.Fire(SpaceInitializing, nil)
|
||||
}
|
||||
|
||||
func (s *spaceImpl) GetApplication(appInterface interface{}) Application {
|
||||
|
@ -1,52 +0,0 @@
|
||||
package vpndialer
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
type Config struct {
|
||||
Address string `protobuf:"bytes,1,opt,name=address" json:"address,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Config) Reset() { *m = Config{} }
|
||||
func (m *Config) String() string { return proto.CompactTextString(m) }
|
||||
func (*Config) ProtoMessage() {}
|
||||
func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func (m *Config) GetAddress() string {
|
||||
if m != nil {
|
||||
return m.Address
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*Config)(nil), "v2ray.core.app.vpndialer.Config")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("v2ray.com/core/app/vpndialer/config.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 150 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2c, 0x33, 0x2a, 0x4a,
|
||||
0xac, 0xd4, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xce, 0x2f, 0x4a, 0xd5, 0x4f, 0x2c, 0x28, 0xd0, 0x2f,
|
||||
0x2b, 0xc8, 0x4b, 0xc9, 0x4c, 0xcc, 0x49, 0x2d, 0xd2, 0x4f, 0xce, 0xcf, 0x4b, 0xcb, 0x4c, 0xd7,
|
||||
0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x92, 0x80, 0x29, 0x2d, 0x4a, 0xd5, 0x4b, 0x2c, 0x28, 0xd0,
|
||||
0x83, 0x2b, 0x53, 0x52, 0xe2, 0x62, 0x73, 0x06, 0xab, 0x14, 0x92, 0xe0, 0x62, 0x4f, 0x4c, 0x49,
|
||||
0x29, 0x4a, 0x2d, 0x2e, 0x96, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x82, 0x71, 0x9d, 0xdc, 0xb8,
|
||||
0x64, 0x92, 0xf3, 0x73, 0xf5, 0x70, 0x99, 0x11, 0xc0, 0x18, 0xc5, 0x09, 0xe7, 0xac, 0x62, 0x92,
|
||||
0x08, 0x33, 0x0a, 0x4a, 0xac, 0xd4, 0x73, 0x06, 0xa9, 0x73, 0x2c, 0x28, 0xd0, 0x0b, 0x2b, 0xc8,
|
||||
0x73, 0x01, 0x4b, 0x25, 0xb1, 0x81, 0x1d, 0x63, 0x0c, 0x08, 0x00, 0x00, 0xff, 0xff, 0x97, 0xfc,
|
||||
0x09, 0x70, 0xb9, 0x00, 0x00, 0x00,
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package v2ray.core.app.vpndialer;
|
||||
option csharp_namespace = "V2Ray.Core.App.VpnDialer";
|
||||
option go_package = "vpndialer";
|
||||
option java_package = "com.v2ray.core.app.vpndialer";
|
||||
option java_multiple_files = true;
|
||||
|
||||
message Config {
|
||||
string address = 1;
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
package unix
|
||||
|
||||
import "v2ray.com/core/common/errors"
|
||||
|
||||
func newError(values ...interface{}) *errors.Error {
|
||||
return errors.New(values...).Path("App", "VPNDialer", "Unix")
|
||||
}
|
@ -1,217 +0,0 @@
|
||||
package unix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"v2ray.com/core/app/vpndialer"
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/transport/internet"
|
||||
)
|
||||
|
||||
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg unix -path App,VPNDialer,Unix
|
||||
|
||||
type status int
|
||||
|
||||
const (
|
||||
statusNew status = iota
|
||||
statusOK
|
||||
statusFail
|
||||
)
|
||||
|
||||
type fdStatus struct {
|
||||
status status
|
||||
fd int
|
||||
callback chan<- error
|
||||
}
|
||||
|
||||
type protector struct {
|
||||
sync.Mutex
|
||||
address string
|
||||
conn *net.UnixConn
|
||||
status chan fdStatus
|
||||
}
|
||||
|
||||
func readFrom(conn *net.UnixConn, schan chan<- fdStatus) {
|
||||
var payload [6]byte
|
||||
for {
|
||||
_, err := conn.Read(payload[:])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fd := serial.BytesToInt(payload[1:5])
|
||||
s := status(payload[5])
|
||||
schan <- fdStatus{
|
||||
fd: fd,
|
||||
status: s,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *protector) dial() (*net.UnixConn, error) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
if m.conn != nil {
|
||||
return m.conn, nil
|
||||
}
|
||||
|
||||
conn, err := net.DialUnix("unix", nil, &net.UnixAddr{
|
||||
Name: m.address,
|
||||
Net: "unix",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.conn = conn
|
||||
m.status = make(chan fdStatus, 32)
|
||||
go readFrom(conn, m.status)
|
||||
go m.monitor(conn)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *protector) close() {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if m.conn == nil {
|
||||
return
|
||||
}
|
||||
m.conn.Close()
|
||||
m.conn = nil
|
||||
}
|
||||
|
||||
func (m *protector) monitor(c *net.UnixConn) {
|
||||
pendingFd := make(map[int]chan<- error, 32)
|
||||
for s := range m.status {
|
||||
switch s.status {
|
||||
case statusNew:
|
||||
pendingFd[s.fd] = s.callback
|
||||
case statusOK:
|
||||
if c, f := pendingFd[s.fd]; f {
|
||||
close(c)
|
||||
delete(pendingFd, s.fd)
|
||||
}
|
||||
case statusFail:
|
||||
if c, f := pendingFd[s.fd]; f {
|
||||
c <- newError("failed to protect fd")
|
||||
close(c)
|
||||
delete(pendingFd, s.fd)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *protector) protect(fd int) error {
|
||||
conn, err := m.dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var payload [6]byte
|
||||
serial.IntToBytes(fd, payload[1:1])
|
||||
payload[5] = byte(statusNew)
|
||||
if _, err := conn.Write(payload[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wait := make(chan error)
|
||||
m.status <- fdStatus{
|
||||
status: statusNew,
|
||||
fd: fd,
|
||||
callback: wait,
|
||||
}
|
||||
return <-wait
|
||||
}
|
||||
|
||||
type App struct {
|
||||
protector *protector
|
||||
dialer *Dialer
|
||||
}
|
||||
|
||||
func NewApp(ctx context.Context, config *vpndialer.Config) (*App, error) {
|
||||
a := &App{
|
||||
dialer: &Dialer{},
|
||||
protector: &protector{
|
||||
address: config.Address,
|
||||
},
|
||||
}
|
||||
a.dialer.protect = a.protector.protect
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (*App) Interface() interface{} {
|
||||
return (*App)(nil)
|
||||
}
|
||||
|
||||
func (a *App) Start() error {
|
||||
internet.UseAlternativeSystemDialer(a.dialer)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) Close() {
|
||||
internet.UseAlternativeSystemDialer(nil)
|
||||
}
|
||||
|
||||
type Dialer struct {
|
||||
protect func(fd int) error
|
||||
}
|
||||
|
||||
func socket(dest net.Destination) (int, error) {
|
||||
switch dest.Network {
|
||||
case net.Network_TCP:
|
||||
return unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, unix.IPPROTO_TCP)
|
||||
case net.Network_UDP:
|
||||
return unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP)
|
||||
default:
|
||||
return 0, newError("unknown network ", dest.Network)
|
||||
}
|
||||
}
|
||||
|
||||
func getIP(addr net.Address) (net.IP, error) {
|
||||
if addr.Family().Either(net.AddressFamilyIPv4, net.AddressFamilyIPv6) {
|
||||
return addr.IP(), nil
|
||||
}
|
||||
ips, err := net.LookupIP(addr.Domain())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ips[0], nil
|
||||
}
|
||||
|
||||
func (d *Dialer) Dial(ctx context.Context, source net.Address, dest net.Destination) (net.Conn, error) {
|
||||
fd, err := socket(dest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.protect(fd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip, err := getIP(dest.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := &unix.SockaddrInet6{
|
||||
Port: int(dest.Port),
|
||||
ZoneId: 0,
|
||||
}
|
||||
copy(addr.Addr[:], ip.To16())
|
||||
|
||||
if err := unix.Connect(fd, addr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
file := os.NewFile(uintptr(fd), "Socket")
|
||||
return net.FileConn(file)
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.Must(common.RegisterConfig((*vpndialer.Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
|
||||
return NewApp(ctx, config.(*vpndialer.Config))
|
||||
}))
|
||||
}
|
@ -1 +0,0 @@
|
||||
package vpndialer
|
21
common/bitmask/byte.go
Normal file
21
common/bitmask/byte.go
Normal file
@ -0,0 +1,21 @@
|
||||
package bitmask
|
||||
|
||||
// Byte is a bitmask in byte.
|
||||
type Byte byte
|
||||
|
||||
// Has returns true if this bitmask contains another bitmask.
|
||||
func (b Byte) Has(bb Byte) bool {
|
||||
return (b & bb) != 0
|
||||
}
|
||||
|
||||
func (b *Byte) Set(bb Byte) {
|
||||
*b |= bb
|
||||
}
|
||||
|
||||
func (b *Byte) Clear(bb Byte) {
|
||||
*b &= ^bb
|
||||
}
|
||||
|
||||
func (b *Byte) Toggle(bb Byte) {
|
||||
*b ^= bb
|
||||
}
|
27
common/bitmask/byte_test.go
Normal file
27
common/bitmask/byte_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package bitmask_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/bitmask"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestBitmaskByte(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
b := Byte(0)
|
||||
b.Set(Byte(1))
|
||||
assert(b.Has(1), IsTrue)
|
||||
|
||||
b.Set(Byte(2))
|
||||
assert(b.Has(2), IsTrue)
|
||||
assert(b.Has(1), IsTrue)
|
||||
|
||||
b.Clear(Byte(1))
|
||||
assert(b.Has(2), IsTrue)
|
||||
assert(b.Has(1), IsFalse)
|
||||
|
||||
b.Toggle(Byte(2))
|
||||
assert(b.Has(2), IsFalse)
|
||||
}
|
@ -1,10 +1,7 @@
|
||||
package buf
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"v2ray.com/core/common/platform"
|
||||
)
|
||||
|
||||
// Pool provides functionality to generate and recycle buffers on demand.
|
||||
@ -45,79 +42,11 @@ func (p *SyncPool) Free(buffer *Buffer) {
|
||||
}
|
||||
}
|
||||
|
||||
// BufferPool is a Pool that utilizes an internal cache.
|
||||
type BufferPool struct {
|
||||
chain chan []byte
|
||||
sub Pool
|
||||
}
|
||||
|
||||
// NewBufferPool creates a new BufferPool with given buffer size, and internal cache size.
|
||||
func NewBufferPool(bufferSize, poolSize uint32) *BufferPool {
|
||||
pool := &BufferPool{
|
||||
chain: make(chan []byte, poolSize),
|
||||
sub: NewSyncPool(bufferSize),
|
||||
}
|
||||
for i := uint32(0); i < poolSize; i++ {
|
||||
pool.chain <- make([]byte, bufferSize)
|
||||
}
|
||||
return pool
|
||||
}
|
||||
|
||||
// Allocate implements Pool.Allocate().
|
||||
func (p *BufferPool) Allocate() *Buffer {
|
||||
select {
|
||||
case b := <-p.chain:
|
||||
return &Buffer{
|
||||
v: b,
|
||||
pool: p,
|
||||
}
|
||||
default:
|
||||
return p.sub.Allocate()
|
||||
}
|
||||
}
|
||||
|
||||
// Free implements Pool.Free().
|
||||
func (p *BufferPool) Free(buffer *Buffer) {
|
||||
if buffer.v == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case p.chain <- buffer.v:
|
||||
default:
|
||||
p.sub.Free(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// Size of a regular buffer.
|
||||
Size = 2 * 1024
|
||||
|
||||
poolSizeEnvKey = "v2ray.buffer.size"
|
||||
)
|
||||
|
||||
var (
|
||||
mediumPool Pool
|
||||
mediumPool Pool = NewSyncPool(Size)
|
||||
)
|
||||
|
||||
func getDefaultPoolSize() int {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64", "386":
|
||||
return 20
|
||||
default:
|
||||
return 5
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
f := platform.EnvFlag{
|
||||
Name: poolSizeEnvKey,
|
||||
AltName: platform.NormalizeEnvName(poolSizeEnvKey),
|
||||
}
|
||||
size := f.GetValueAsInt(getDefaultPoolSize())
|
||||
if size > 0 {
|
||||
totalByteSize := uint32(size) * 1024 * 1024
|
||||
mediumPool = NewBufferPool(Size, totalByteSize/Size)
|
||||
} else {
|
||||
mediumPool = NewSyncPool(Size)
|
||||
}
|
||||
}
|
||||
|
@ -6,64 +6,64 @@ import (
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestBufferClear(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
buffer := New()
|
||||
defer buffer.Release()
|
||||
|
||||
payload := "Bytes"
|
||||
buffer.Append([]byte(payload))
|
||||
assert.Int(buffer.Len()).Equals(len(payload))
|
||||
assert(buffer.Len(), Equals, len(payload))
|
||||
|
||||
buffer.Clear()
|
||||
assert.Int(buffer.Len()).Equals(0)
|
||||
assert(buffer.Len(), Equals, 0)
|
||||
}
|
||||
|
||||
func TestBufferIsEmpty(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
buffer := New()
|
||||
defer buffer.Release()
|
||||
|
||||
assert.Bool(buffer.IsEmpty()).IsTrue()
|
||||
assert(buffer.IsEmpty(), IsTrue)
|
||||
}
|
||||
|
||||
func TestBufferString(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
buffer := New()
|
||||
defer buffer.Release()
|
||||
|
||||
assert.Error(buffer.AppendSupplier(serial.WriteString("Test String"))).IsNil()
|
||||
assert.String(buffer.String()).Equals("Test String")
|
||||
assert(buffer.AppendSupplier(serial.WriteString("Test String")), IsNil)
|
||||
assert(buffer.String(), Equals, "Test String")
|
||||
}
|
||||
|
||||
func TestBufferWrite(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
buffer := NewLocal(8)
|
||||
nBytes, err := buffer.Write([]byte("abcd"))
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(nBytes).Equals(4)
|
||||
assert(err, IsNil)
|
||||
assert(nBytes, Equals, 4)
|
||||
nBytes, err = buffer.Write([]byte("abcde"))
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(nBytes).Equals(4)
|
||||
assert.String(buffer.String()).Equals("abcdabcd")
|
||||
assert(err, IsNil)
|
||||
assert(nBytes, Equals, 4)
|
||||
assert(buffer.String(), Equals, "abcdabcd")
|
||||
}
|
||||
|
||||
func TestSyncPool(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
p := NewSyncPool(32)
|
||||
b := p.Allocate()
|
||||
assert.Int(b.Len()).Equals(0)
|
||||
assert(b.Len(), Equals, 0)
|
||||
|
||||
assert.Error(b.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
|
||||
assert.Int(b.Len()).Equals(32)
|
||||
assert(b.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
|
||||
assert(b.Len(), Equals, 32)
|
||||
|
||||
b.Release()
|
||||
}
|
||||
|
@ -1,53 +0,0 @@
|
||||
package buf
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// BufferedReader is a reader with internal cache.
|
||||
type BufferedReader struct {
|
||||
reader io.Reader
|
||||
buffer *Buffer
|
||||
buffered bool
|
||||
}
|
||||
|
||||
// NewBufferedReader creates a new BufferedReader based on an io.Reader.
|
||||
func NewBufferedReader(rawReader io.Reader) *BufferedReader {
|
||||
return &BufferedReader{
|
||||
reader: rawReader,
|
||||
buffer: NewLocal(1024),
|
||||
buffered: true,
|
||||
}
|
||||
}
|
||||
|
||||
// IsBuffered returns true if the internal cache is effective.
|
||||
func (v *BufferedReader) IsBuffered() bool {
|
||||
return v.buffered
|
||||
}
|
||||
|
||||
// SetBuffered is to enable or disable internal cache. If cache is disabled,
|
||||
// Read() calls will be delegated to the underlying io.Reader directly.
|
||||
func (v *BufferedReader) SetBuffered(cached bool) {
|
||||
v.buffered = cached
|
||||
}
|
||||
|
||||
// Read implements io.Reader.Read().
|
||||
func (v *BufferedReader) Read(b []byte) (int, error) {
|
||||
if !v.buffered || v.buffer == nil {
|
||||
if !v.buffer.IsEmpty() {
|
||||
return v.buffer.Read(b)
|
||||
}
|
||||
return v.reader.Read(b)
|
||||
}
|
||||
if v.buffer.IsEmpty() {
|
||||
if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if v.buffer.IsEmpty() {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return v.buffer.Read(b)
|
||||
}
|
@ -1,36 +0,0 @@
|
||||
package buf_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/testing/assert"
|
||||
)
|
||||
|
||||
func TestBufferedReader(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
|
||||
content := New()
|
||||
assert.Error(content.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
|
||||
|
||||
len := content.Len()
|
||||
|
||||
reader := NewBufferedReader(content)
|
||||
assert.Bool(reader.IsBuffered()).IsTrue()
|
||||
|
||||
payload := make([]byte, 16)
|
||||
|
||||
nBytes, err := reader.Read(payload)
|
||||
assert.Int(nBytes).Equals(16)
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
len2 := content.Len()
|
||||
assert.Int(len - len2).GreaterThan(16)
|
||||
|
||||
nBytes, err = reader.Read(payload)
|
||||
assert.Int(nBytes).Equals(16)
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
assert.Int(content.Len()).Equals(len2)
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package buf
|
||||
|
||||
import "io"
|
||||
|
||||
// BufferedWriter is an io.Writer with internal buffer. It writes to underlying writer when buffer is full or on demand.
|
||||
// This type is not thread safe.
|
||||
type BufferedWriter struct {
|
||||
writer io.Writer
|
||||
buffer *Buffer
|
||||
buffered bool
|
||||
}
|
||||
|
||||
// NewBufferedWriter creates a new BufferedWriter.
|
||||
func NewBufferedWriter(writer io.Writer) *BufferedWriter {
|
||||
return NewBufferedWriterSize(writer, 1024)
|
||||
}
|
||||
|
||||
func NewBufferedWriterSize(writer io.Writer, size uint32) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
writer: writer,
|
||||
buffer: NewLocal(int(size)),
|
||||
buffered: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (w *BufferedWriter) Write(b []byte) (int, error) {
|
||||
if !w.buffered || w.buffer == nil {
|
||||
return w.writer.Write(b)
|
||||
}
|
||||
bytesWritten := 0
|
||||
for bytesWritten < len(b) {
|
||||
nBytes, err := w.buffer.Write(b[bytesWritten:])
|
||||
if err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
bytesWritten += nBytes
|
||||
if w.buffer.IsFull() {
|
||||
if err := w.Flush(); err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return bytesWritten, nil
|
||||
}
|
||||
|
||||
// Flush writes all buffered content into underlying writer, if any.
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
defer w.buffer.Clear()
|
||||
for !w.buffer.IsEmpty() {
|
||||
nBytes, err := w.writer.Write(w.buffer.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.buffer.SliceFrom(nBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsBuffered returns true if this BufferedWriter holds a buffer.
|
||||
func (w *BufferedWriter) IsBuffered() bool {
|
||||
return w.buffered
|
||||
}
|
||||
|
||||
// SetBuffered controls whether the BufferedWriter holds a buffer for writing. If not buffered, any write() calls into underlying writer directly.
|
||||
func (w *BufferedWriter) SetBuffered(cached bool) error {
|
||||
w.buffered = cached
|
||||
if !cached && !w.buffer.IsEmpty() {
|
||||
return w.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
package buf_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/testing/assert"
|
||||
)
|
||||
|
||||
func TestBufferedWriter(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
|
||||
content := New()
|
||||
|
||||
writer := NewBufferedWriter(content)
|
||||
assert.Bool(writer.IsBuffered()).IsTrue()
|
||||
|
||||
payload := make([]byte, 16)
|
||||
|
||||
nBytes, err := writer.Write(payload)
|
||||
assert.Int(nBytes).Equals(16)
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
assert.Bool(content.IsEmpty()).IsTrue()
|
||||
|
||||
assert.Error(writer.SetBuffered(false)).IsNil()
|
||||
assert.Int(content.Len()).Equals(16)
|
||||
}
|
||||
|
||||
func TestBufferedWriterLargePayload(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
|
||||
content := NewLocal(128 * 1024)
|
||||
|
||||
writer := NewBufferedWriter(content)
|
||||
assert.Bool(writer.IsBuffered()).IsTrue()
|
||||
|
||||
payload := make([]byte, 64*1024)
|
||||
rand.Read(payload)
|
||||
|
||||
nBytes, err := writer.Write(payload[:512])
|
||||
assert.Int(nBytes).Equals(512)
|
||||
assert.Error(err).IsNil()
|
||||
|
||||
assert.Bool(content.IsEmpty()).IsTrue()
|
||||
|
||||
nBytes, err = writer.Write(payload[512:])
|
||||
assert.Error(err).IsNil()
|
||||
assert.Error(writer.Flush()).IsNil()
|
||||
assert.Int(nBytes).Equals(64*1024 - 512)
|
||||
assert.Bytes(content.Bytes()).Equals(payload)
|
||||
}
|
@ -17,7 +17,7 @@ type copyHandler struct {
|
||||
}
|
||||
|
||||
func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
|
||||
mb, err := reader.Read()
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
for _, handler := range h.onReadError {
|
||||
err = handler(err)
|
||||
@ -27,7 +27,7 @@ func (h *copyHandler) readFrom(reader Reader) (MultiBuffer, error) {
|
||||
}
|
||||
|
||||
func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
|
||||
err := writer.Write(mb)
|
||||
err := writer.WriteMultiBuffer(mb)
|
||||
if err != nil {
|
||||
for _, handler := range h.onWriteError {
|
||||
err = handler(err)
|
||||
@ -36,8 +36,14 @@ func (h *copyHandler) writeTo(writer Writer, mb MultiBuffer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
type SizeCounter struct {
|
||||
Size int64
|
||||
}
|
||||
|
||||
// CopyOption is an option for copying data.
|
||||
type CopyOption func(*copyHandler)
|
||||
|
||||
// IgnoreReaderError is a CopyOption that ignores errors from reader. Copy will continue in such case.
|
||||
func IgnoreReaderError() CopyOption {
|
||||
return func(handler *copyHandler) {
|
||||
handler.onReadError = append(handler.onReadError, func(err error) error {
|
||||
@ -46,6 +52,7 @@ func IgnoreReaderError() CopyOption {
|
||||
}
|
||||
}
|
||||
|
||||
// IgnoreWriterError is a CopyOption that ignores errors from writer. Copy will continue in such case.
|
||||
func IgnoreWriterError() CopyOption {
|
||||
return func(handler *copyHandler) {
|
||||
handler.onWriteError = append(handler.onWriteError, func(err error) error {
|
||||
@ -54,6 +61,7 @@ func IgnoreWriterError() CopyOption {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateActivity is a CopyOption to update activity on each data copy operation.
|
||||
func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
|
||||
return func(handler *copyHandler) {
|
||||
handler.onData = append(handler.onData, func(MultiBuffer) {
|
||||
@ -62,31 +70,34 @@ func UpdateActivity(timer signal.ActivityUpdater) CopyOption {
|
||||
}
|
||||
}
|
||||
|
||||
// CountSize is a CopyOption that sums the total size of data copied into the given SizeCounter.
|
||||
func CountSize(sc *SizeCounter) CopyOption {
|
||||
return func(handler *copyHandler) {
|
||||
handler.onData = append(handler.onData, func(b MultiBuffer) {
|
||||
sc.Size += int64(b.Len())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
|
||||
for {
|
||||
buffer, err := handler.readFrom(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !buffer.IsEmpty() {
|
||||
for _, handler := range handler.onData {
|
||||
handler(buffer)
|
||||
}
|
||||
|
||||
if buffer.IsEmpty() {
|
||||
buffer.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
for _, handler := range handler.onData {
|
||||
handler(buffer)
|
||||
}
|
||||
|
||||
if err := handler.writeTo(writer, buffer); err != nil {
|
||||
buffer.Release()
|
||||
if werr := handler.writeTo(writer, buffer); werr != nil {
|
||||
buffer.Release()
|
||||
return werr
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy dumps all payload from reader to writer or stops when an error occurs.
|
||||
// ActivityTimer gets updated as soon as there is a payload.
|
||||
// Copy dumps all payload from reader to writer or stops when an error occurs. It returns nil when EOF.
|
||||
func Copy(reader Reader, writer Writer, options ...CopyOption) error {
|
||||
handler := new(copyHandler)
|
||||
for _, option := range options {
|
||||
|
@ -5,22 +5,24 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Reader extends io.Reader with alloc.Buffer.
|
||||
// Reader extends io.Reader with MultiBuffer.
|
||||
type Reader interface {
|
||||
// Read reads content from underlying reader, and put it into an alloc.Buffer.
|
||||
Read() (MultiBuffer, error)
|
||||
// ReadMultiBuffer reads content from underlying reader, and put it into a MultiBuffer.
|
||||
ReadMultiBuffer() (MultiBuffer, error)
|
||||
}
|
||||
|
||||
// ErrReadTimeout is an error that happens with IO timeout.
|
||||
var ErrReadTimeout = newError("IO timeout")
|
||||
|
||||
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
|
||||
type TimeoutReader interface {
|
||||
ReadTimeout(time.Duration) (MultiBuffer, error)
|
||||
}
|
||||
|
||||
// Writer extends io.Writer with alloc.Buffer.
|
||||
// Writer extends io.Writer with MultiBuffer.
|
||||
type Writer interface {
|
||||
// Write writes an alloc.Buffer into underlying writer.
|
||||
Write(MultiBuffer) error
|
||||
// WriteMultiBuffer writes a MultiBuffer into underlying writer.
|
||||
WriteMultiBuffer(MultiBuffer) error
|
||||
}
|
||||
|
||||
// ReadFrom creates a Supplier to read from a given io.Reader.
|
||||
@ -47,57 +49,21 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
|
||||
// NewReader creates a new Reader.
|
||||
// The Reader instance doesn't take the ownership of reader.
|
||||
func NewReader(reader io.Reader) Reader {
|
||||
if mr, ok := reader.(MultiBufferReader); ok {
|
||||
return &readerAdpater{
|
||||
MultiBufferReader: mr,
|
||||
}
|
||||
if mr, ok := reader.(Reader); ok {
|
||||
return mr
|
||||
}
|
||||
|
||||
return &BytesToBufferReader{
|
||||
reader: reader,
|
||||
buffer: make([]byte, 32*1024),
|
||||
}
|
||||
}
|
||||
|
||||
func NewMergingReader(reader io.Reader) Reader {
|
||||
return NewMergingReaderSize(reader, 32*1024)
|
||||
}
|
||||
|
||||
func NewMergingReaderSize(reader io.Reader, size uint32) Reader {
|
||||
return &BytesToBufferReader{
|
||||
reader: reader,
|
||||
buffer: make([]byte, size),
|
||||
}
|
||||
}
|
||||
|
||||
// ToBytesReader converts a Reaaer to io.Reader.
|
||||
func ToBytesReader(stream Reader) io.Reader {
|
||||
return &bufferToBytesReader{
|
||||
stream: stream,
|
||||
}
|
||||
return NewBytesToBufferReader(reader)
|
||||
}
|
||||
|
||||
// NewWriter creates a new Writer.
|
||||
func NewWriter(writer io.Writer) Writer {
|
||||
if mw, ok := writer.(MultiBufferWriter); ok {
|
||||
return &writerAdapter{
|
||||
writer: mw,
|
||||
}
|
||||
if mw, ok := writer.(Writer); ok {
|
||||
return mw
|
||||
}
|
||||
|
||||
return &BufferToBytesWriter{
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
func NewMergingWriter(writer io.Writer) Writer {
|
||||
return NewMergingWriterSize(writer, 4096)
|
||||
}
|
||||
|
||||
func NewMergingWriterSize(writer io.Writer, size uint32) Writer {
|
||||
return &mergingWriter{
|
||||
writer: writer,
|
||||
buffer: make([]byte, size),
|
||||
Writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,10 +72,3 @@ func NewSequentialWriter(writer io.Writer) Writer {
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
// ToBytesWriter converts a Writer to io.Writer
|
||||
func ToBytesWriter(writer Writer) io.Writer {
|
||||
return &bytesToBufferWriter{
|
||||
writer: writer,
|
||||
}
|
||||
}
|
||||
|
@ -1,21 +1,53 @@
|
||||
package buf
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
type MultiBufferWriter interface {
|
||||
WriteMultiBuffer(MultiBuffer) error
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/errors"
|
||||
)
|
||||
|
||||
// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF.
|
||||
func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
|
||||
mb := NewMultiBufferCap(128)
|
||||
|
||||
for {
|
||||
b := New()
|
||||
err := b.Reset(ReadFrom(reader))
|
||||
if b.IsEmpty() {
|
||||
b.Release()
|
||||
} else {
|
||||
mb.Append(b)
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Cause(err) == io.EOF {
|
||||
return mb, nil
|
||||
}
|
||||
mb.Release()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type MultiBufferReader interface {
|
||||
ReadMultiBuffer() (MultiBuffer, error)
|
||||
// ReadAllToBytes reads all content from the reader into a byte array, until EOF.
|
||||
func ReadAllToBytes(reader io.Reader) ([]byte, error) {
|
||||
mb, err := ReadAllToMultiBuffer(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b := make([]byte, mb.Len())
|
||||
common.Must2(mb.Read(b))
|
||||
mb.Release()
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// MultiBuffer is a list of Buffers. The order of Buffer matters.
|
||||
type MultiBuffer []*Buffer
|
||||
|
||||
// NewMultiBuffer creates a new MultiBuffer instance.
|
||||
func NewMultiBuffer() MultiBuffer {
|
||||
return MultiBuffer(make([]*Buffer, 0, 128))
|
||||
// NewMultiBufferCap creates a new MultiBuffer instance.
|
||||
func NewMultiBufferCap(capacity int) MultiBuffer {
|
||||
return MultiBuffer(make([]*Buffer, 0, capacity))
|
||||
}
|
||||
|
||||
// NewMultiBufferValue wraps a list of Buffers into MultiBuffer.
|
||||
@ -23,14 +55,17 @@ func NewMultiBufferValue(b ...*Buffer) MultiBuffer {
|
||||
return MultiBuffer(b)
|
||||
}
|
||||
|
||||
// Append appends buffer to the end of this MultiBuffer
|
||||
func (mb *MultiBuffer) Append(buf *Buffer) {
|
||||
*mb = append(*mb, buf)
|
||||
}
|
||||
|
||||
// AppendMulti appends a MultiBuffer to the end of this one.
|
||||
func (mb *MultiBuffer) AppendMulti(buf MultiBuffer) {
|
||||
*mb = append(*mb, buf...)
|
||||
}
|
||||
|
||||
// Copy copied the begining part of the MultiBuffer into the given byte array.
|
||||
func (mb MultiBuffer) Copy(b []byte) int {
|
||||
total := 0
|
||||
for _, bb := range mb {
|
||||
@ -43,6 +78,7 @@ func (mb MultiBuffer) Copy(b []byte) int {
|
||||
return total
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (mb *MultiBuffer) Read(b []byte) (int, error) {
|
||||
endIndex := len(*mb)
|
||||
totalBytes := 0
|
||||
@ -52,6 +88,7 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
|
||||
b = b[nBytes:]
|
||||
if bb.IsEmpty() {
|
||||
bb.Release()
|
||||
(*mb)[i] = nil
|
||||
} else {
|
||||
endIndex = i
|
||||
break
|
||||
@ -61,6 +98,7 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (mb *MultiBuffer) Write(b []byte) {
|
||||
n := len(*mb)
|
||||
if n > 0 && !(*mb)[n-1].IsFull() {
|
||||
@ -96,11 +134,12 @@ func (mb MultiBuffer) IsEmpty() bool {
|
||||
}
|
||||
|
||||
// Release releases all Buffers in the MultiBuffer.
|
||||
func (mb MultiBuffer) Release() {
|
||||
for i, b := range mb {
|
||||
func (mb *MultiBuffer) Release() {
|
||||
for i, b := range *mb {
|
||||
b.Release()
|
||||
mb[i] = nil
|
||||
(*mb)[i] = nil
|
||||
}
|
||||
*mb = nil
|
||||
}
|
||||
|
||||
// ToNetBuffers converts this MultiBuffer to net.Buffers. The return net.Buffers points to the same content of the MultiBuffer.
|
||||
@ -112,8 +151,9 @@ func (mb MultiBuffer) ToNetBuffers() net.Buffers {
|
||||
return bs
|
||||
}
|
||||
|
||||
// SliceBySize splits the begining of this MultiBuffer into another one, for at most size bytes.
|
||||
func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
|
||||
slice := NewMultiBuffer()
|
||||
slice := NewMultiBufferCap(10)
|
||||
sliceSize := 0
|
||||
endIndex := len(*mb)
|
||||
for i, b := range *mb {
|
||||
@ -123,16 +163,24 @@ func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
|
||||
}
|
||||
sliceSize += b.Len()
|
||||
slice.Append(b)
|
||||
(*mb)[i] = nil
|
||||
}
|
||||
*mb = (*mb)[endIndex:]
|
||||
if endIndex == 0 && len(*mb) > 0 {
|
||||
b := New()
|
||||
common.Must(b.Reset(ReadFullFrom((*mb)[0], size)))
|
||||
return NewMultiBufferValue(b)
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
// SplitFirst splits out the first Buffer in this MultiBuffer.
|
||||
func (mb *MultiBuffer) SplitFirst() *Buffer {
|
||||
if len(*mb) == 0 {
|
||||
return nil
|
||||
}
|
||||
b := (*mb)[0]
|
||||
(*mb)[0] = nil
|
||||
*mb = (*mb)[1:]
|
||||
return b
|
||||
}
|
||||
|
@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestMultiBufferRead(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
b1 := New()
|
||||
b1.AppendBytes('a', 'b')
|
||||
@ -19,17 +19,17 @@ func TestMultiBufferRead(t *testing.T) {
|
||||
|
||||
bs := make([]byte, 32)
|
||||
nBytes, err := mb.Read(bs)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(nBytes).Equals(4)
|
||||
assert.Bytes(bs[:nBytes]).Equals([]byte("abcd"))
|
||||
assert(err, IsNil)
|
||||
assert(nBytes, Equals, 4)
|
||||
assert(bs[:nBytes], Equals, []byte("abcd"))
|
||||
}
|
||||
|
||||
func TestMultiBufferAppend(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
var mb MultiBuffer
|
||||
b := New()
|
||||
b.AppendBytes('a', 'b')
|
||||
mb.Append(b)
|
||||
assert.Int(mb.Len()).Equals(2)
|
||||
assert(mb.Len(), Equals, 2)
|
||||
}
|
||||
|
@ -6,38 +6,86 @@ import (
|
||||
"v2ray.com/core/common/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Reader = (*BytesToBufferReader)(nil)
|
||||
_ io.Reader = (*BytesToBufferReader)(nil)
|
||||
)
|
||||
|
||||
// BytesToBufferReader is a Reader that adjusts its reading speed automatically.
|
||||
type BytesToBufferReader struct {
|
||||
reader io.Reader
|
||||
io.Reader
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// Read implements Reader.Read().
|
||||
func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
|
||||
nBytes, err := r.reader.Read(r.buffer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func NewBytesToBufferReader(reader io.Reader) Reader {
|
||||
return &BytesToBufferReader{
|
||||
Reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BytesToBufferReader) readSmall() (MultiBuffer, error) {
|
||||
b := New()
|
||||
err := b.Reset(ReadFrom(r.Reader))
|
||||
if b.IsFull() {
|
||||
r.buffer = make([]byte, 32*1024)
|
||||
}
|
||||
if !b.IsEmpty() {
|
||||
return NewMultiBufferValue(b), nil
|
||||
}
|
||||
b.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ReadMultiBuffer implements Reader.
|
||||
func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
|
||||
if r.buffer == nil {
|
||||
return r.readSmall()
|
||||
}
|
||||
|
||||
mb := NewMultiBuffer()
|
||||
mb.Write(r.buffer[:nBytes])
|
||||
return mb, nil
|
||||
nBytes, err := r.Reader.Read(r.buffer)
|
||||
if nBytes > 0 {
|
||||
mb := NewMultiBufferCap(nBytes/Size + 1)
|
||||
mb.Write(r.buffer[:nBytes])
|
||||
return mb, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type readerAdpater struct {
|
||||
MultiBufferReader
|
||||
}
|
||||
var (
|
||||
_ Reader = (*BufferedReader)(nil)
|
||||
_ io.Reader = (*BufferedReader)(nil)
|
||||
_ io.ByteReader = (*BufferedReader)(nil)
|
||||
_ io.WriterTo = (*BufferedReader)(nil)
|
||||
)
|
||||
|
||||
func (r *readerAdpater) Read() (MultiBuffer, error) {
|
||||
return r.ReadMultiBuffer()
|
||||
}
|
||||
|
||||
type bufferToBytesReader struct {
|
||||
type BufferedReader struct {
|
||||
stream Reader
|
||||
leftOver MultiBuffer
|
||||
buffered bool
|
||||
}
|
||||
|
||||
func (r *bufferToBytesReader) Read(b []byte) (int, error) {
|
||||
func NewBufferedReader(reader Reader) *BufferedReader {
|
||||
return &BufferedReader{
|
||||
stream: reader,
|
||||
buffered: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BufferedReader) SetBuffered(f bool) {
|
||||
r.buffered = f
|
||||
}
|
||||
|
||||
func (r *BufferedReader) IsBuffered() bool {
|
||||
return r.buffered
|
||||
}
|
||||
|
||||
func (r *BufferedReader) ReadByte() (byte, error) {
|
||||
var b [1]byte
|
||||
_, err := r.Read(b[:])
|
||||
return b[0], err
|
||||
}
|
||||
|
||||
func (r *BufferedReader) Read(b []byte) (int, error) {
|
||||
if r.leftOver != nil {
|
||||
nBytes, _ := r.leftOver.Read(b)
|
||||
if r.leftOver.IsEmpty() {
|
||||
@ -47,51 +95,75 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
|
||||
return nBytes, nil
|
||||
}
|
||||
|
||||
mb, err := r.stream.Read()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
if !r.buffered {
|
||||
if reader, ok := r.stream.(io.Reader); ok {
|
||||
return reader.Read(b)
|
||||
}
|
||||
}
|
||||
|
||||
nBytes, _ := mb.Read(b)
|
||||
if !mb.IsEmpty() {
|
||||
r.leftOver = mb
|
||||
mb, err := r.stream.ReadMultiBuffer()
|
||||
if mb != nil {
|
||||
nBytes, _ := mb.Read(b)
|
||||
if !mb.IsEmpty() {
|
||||
r.leftOver = mb
|
||||
}
|
||||
return nBytes, err
|
||||
}
|
||||
return nBytes, nil
|
||||
return 0, err
|
||||
}
|
||||
|
||||
func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) {
|
||||
func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
|
||||
if r.leftOver != nil {
|
||||
mb := r.leftOver
|
||||
r.leftOver = nil
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
return r.stream.Read()
|
||||
return r.stream.ReadMultiBuffer()
|
||||
}
|
||||
|
||||
func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) {
|
||||
// ReadAtMost returns a MultiBuffer with at most size.
|
||||
func (r *BufferedReader) ReadAtMost(size int) (MultiBuffer, error) {
|
||||
if r.leftOver == nil {
|
||||
mb, err := r.stream.ReadMultiBuffer()
|
||||
if mb.IsEmpty() && err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.leftOver = mb
|
||||
}
|
||||
|
||||
mb := r.leftOver.SliceBySize(size)
|
||||
if r.leftOver.IsEmpty() {
|
||||
r.leftOver = nil
|
||||
}
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
|
||||
mbWriter := NewWriter(writer)
|
||||
totalBytes := int64(0)
|
||||
if r.leftOver != nil {
|
||||
if err := mbWriter.Write(r.leftOver); err != nil {
|
||||
totalBytes += int64(r.leftOver.Len())
|
||||
if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
totalBytes += int64(r.leftOver.Len())
|
||||
}
|
||||
|
||||
for {
|
||||
mb, err := r.stream.Read()
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
mb, err := r.stream.ReadMultiBuffer()
|
||||
if mb != nil {
|
||||
totalBytes += int64(mb.Len())
|
||||
if werr := mbWriter.WriteMultiBuffer(mb); werr != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
}
|
||||
totalBytes += int64(mb.Len())
|
||||
if err := mbWriter.Write(mb); err != nil {
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) {
|
||||
func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
|
||||
nBytes, err := r.writeToInternal(writer)
|
||||
if errors.Cause(err) == io.EOF {
|
||||
return nBytes, nil
|
||||
|
@ -7,64 +7,66 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/testing/assert"
|
||||
"v2ray.com/core/transport/ray"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestAdaptiveReader(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
rawContent := make([]byte, 1024*1024)
|
||||
buffer := bytes.NewBuffer(rawContent)
|
||||
reader := NewReader(bytes.NewReader(make([]byte, 1024*1024)))
|
||||
b, err := reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(b.Len(), Equals, 2*1024)
|
||||
|
||||
reader := NewReader(buffer)
|
||||
b, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(b.Len()).Equals(32 * 1024)
|
||||
b, err = reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(b.Len(), Equals, 32*1024)
|
||||
}
|
||||
|
||||
func TestBytesReaderWriteTo(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
stream := ray.NewStream(context.Background())
|
||||
reader := ToBytesReader(stream)
|
||||
reader := NewBufferedReader(stream)
|
||||
b1 := New()
|
||||
b1.AppendBytes('a', 'b', 'c')
|
||||
b2 := New()
|
||||
b2.AppendBytes('e', 'f', 'g')
|
||||
assert.Error(stream.Write(NewMultiBufferValue(b1, b2))).IsNil()
|
||||
assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
|
||||
stream.Close()
|
||||
|
||||
stream2 := ray.NewStream(context.Background())
|
||||
writer := ToBytesWriter(stream2)
|
||||
writer := NewBufferedWriter(stream2)
|
||||
writer.SetBuffered(false)
|
||||
|
||||
nBytes, err := io.Copy(writer, reader)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int64(nBytes).Equals(6)
|
||||
assert(err, IsNil)
|
||||
assert(nBytes, Equals, int64(6))
|
||||
|
||||
mb, err := stream2.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(len(mb)).Equals(2)
|
||||
assert.String(mb[0].String()).Equals("abc")
|
||||
assert.String(mb[1].String()).Equals("efg")
|
||||
mb, err := stream2.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(len(mb), Equals, 2)
|
||||
assert(mb[0].String(), Equals, "abc")
|
||||
assert(mb[1].String(), Equals, "efg")
|
||||
}
|
||||
|
||||
func TestBytesReaderMultiBuffer(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
stream := ray.NewStream(context.Background())
|
||||
reader := ToBytesReader(stream)
|
||||
reader := NewBufferedReader(stream)
|
||||
b1 := New()
|
||||
b1.AppendBytes('a', 'b', 'c')
|
||||
b2 := New()
|
||||
b2.AppendBytes('e', 'f', 'g')
|
||||
assert.Error(stream.Write(NewMultiBufferValue(b1, b2))).IsNil()
|
||||
assert(stream.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
|
||||
stream.Close()
|
||||
|
||||
mbReader := NewReader(reader)
|
||||
mb, err := mbReader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(len(mb)).Equals(2)
|
||||
assert.String(mb[0].String()).Equals("abc")
|
||||
assert.String(mb[1].String()).Equals("efg")
|
||||
mb, err := mbReader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(len(mb), Equals, 2)
|
||||
assert(mb[0].String(), Equals, "abc")
|
||||
assert(mb[1].String(), Equals, "efg")
|
||||
}
|
||||
|
@ -1,52 +1,164 @@
|
||||
package buf
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"io"
|
||||
|
||||
"v2ray.com/core/common/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
_ io.ReaderFrom = (*BufferToBytesWriter)(nil)
|
||||
_ io.Writer = (*BufferToBytesWriter)(nil)
|
||||
_ Writer = (*BufferToBytesWriter)(nil)
|
||||
)
|
||||
|
||||
// BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
|
||||
type BufferToBytesWriter struct {
|
||||
writer io.Writer
|
||||
io.Writer
|
||||
}
|
||||
|
||||
// Write implements Writer.Write(). Write() takes ownership of the given buffer.
|
||||
func (w *BufferToBytesWriter) Write(mb MultiBuffer) error {
|
||||
func NewBufferToBytesWriter(writer io.Writer) *BufferToBytesWriter {
|
||||
return &BufferToBytesWriter{
|
||||
Writer: writer,
|
||||
}
|
||||
}
|
||||
|
||||
// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
|
||||
func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
bs := mb.ToNetBuffers()
|
||||
_, err := bs.WriteTo(w.writer)
|
||||
_, err := bs.WriteTo(w)
|
||||
return err
|
||||
}
|
||||
|
||||
type writerAdapter struct {
|
||||
writer MultiBufferWriter
|
||||
// ReadFrom implements io.ReaderFrom.
|
||||
func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
|
||||
var sc SizeCounter
|
||||
err := Copy(NewReader(reader), w, CountSize(&sc))
|
||||
return sc.Size, err
|
||||
}
|
||||
|
||||
// Write implements buf.MultiBufferWriter.
|
||||
func (w *writerAdapter) Write(mb MultiBuffer) error {
|
||||
return w.writer.WriteMultiBuffer(mb)
|
||||
var (
|
||||
_ io.ReaderFrom = (*BufferedWriter)(nil)
|
||||
_ io.Writer = (*BufferedWriter)(nil)
|
||||
_ Writer = (*BufferedWriter)(nil)
|
||||
_ io.ByteWriter = (*BufferedWriter)(nil)
|
||||
)
|
||||
|
||||
// BufferedWriter is a Writer with internal buffer.
|
||||
type BufferedWriter struct {
|
||||
writer Writer
|
||||
buffer *Buffer
|
||||
buffered bool
|
||||
}
|
||||
|
||||
type mergingWriter struct {
|
||||
writer io.Writer
|
||||
buffer []byte
|
||||
// NewBufferedWriter creates a new BufferedWriter.
|
||||
func NewBufferedWriter(writer Writer) *BufferedWriter {
|
||||
return &BufferedWriter{
|
||||
writer: writer,
|
||||
buffer: New(),
|
||||
buffered: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mergingWriter) Write(mb MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
func (w *BufferedWriter) WriteByte(c byte) error {
|
||||
_, err := w.Write([]byte{c})
|
||||
return err
|
||||
}
|
||||
|
||||
for !mb.IsEmpty() {
|
||||
nBytes, _ := mb.Read(w.buffer)
|
||||
if _, err := w.writer.Write(w.buffer[:nBytes]); err != nil {
|
||||
// Write implements io.Writer.
|
||||
func (w *BufferedWriter) Write(b []byte) (int, error) {
|
||||
if !w.buffered {
|
||||
if writer, ok := w.writer.(io.Writer); ok {
|
||||
return writer.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
totalBytes := 0
|
||||
for len(b) > 0 {
|
||||
if w.buffer == nil {
|
||||
w.buffer = New()
|
||||
}
|
||||
|
||||
nBytes, err := w.buffer.Write(b)
|
||||
totalBytes += nBytes
|
||||
if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
if !w.buffered || w.buffer.IsFull() {
|
||||
if err := w.Flush(); err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
}
|
||||
b = b[nBytes:]
|
||||
}
|
||||
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
||||
// WriteMultiBuffer implements Writer. It takes ownership of the given MultiBuffer.
|
||||
func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
|
||||
if !w.buffered {
|
||||
return w.writer.WriteMultiBuffer(b)
|
||||
}
|
||||
|
||||
defer b.Release()
|
||||
|
||||
for !b.IsEmpty() {
|
||||
if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
|
||||
return err
|
||||
}
|
||||
if w.buffer.IsFull() {
|
||||
if err := w.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush flushes buffered content into underlying writer.
|
||||
func (w *BufferedWriter) Flush() error {
|
||||
if !w.buffer.IsEmpty() {
|
||||
if err := w.writer.WriteMultiBuffer(NewMultiBufferValue(w.buffer)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if w.buffered {
|
||||
w.buffer = New()
|
||||
} else {
|
||||
w.buffer = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *BufferedWriter) SetBuffered(f bool) error {
|
||||
w.buffered = f
|
||||
if !f {
|
||||
return w.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadFrom implements io.ReaderFrom.
|
||||
func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
|
||||
if err := w.SetBuffered(false); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var sc SizeCounter
|
||||
err := Copy(NewReader(reader), w, CountSize(&sc))
|
||||
return sc.Size, err
|
||||
}
|
||||
|
||||
type seqWriter struct {
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func (w *seqWriter) Write(mb MultiBuffer) error {
|
||||
func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
for _, b := range mb {
|
||||
@ -61,54 +173,38 @@ func (w *seqWriter) Write(mb MultiBuffer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ MultiBufferWriter = (*bytesToBufferWriter)(nil)
|
||||
)
|
||||
|
||||
type bytesToBufferWriter struct {
|
||||
writer Writer
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
|
||||
mb := NewMultiBuffer()
|
||||
mb.Write(payload)
|
||||
if err := w.writer.Write(mb); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(payload), nil
|
||||
}
|
||||
|
||||
func (w *bytesToBufferWriter) WriteMultiBuffer(mb MultiBuffer) error {
|
||||
return w.writer.Write(mb)
|
||||
}
|
||||
|
||||
func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) {
|
||||
mbReader := NewReader(reader)
|
||||
totalBytes := int64(0)
|
||||
eof := false
|
||||
for !eof {
|
||||
mb, err := mbReader.Read()
|
||||
if err == io.EOF {
|
||||
eof = true
|
||||
} else if err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
totalBytes += int64(mb.Len())
|
||||
if err := w.writer.Write(mb); err != nil {
|
||||
return totalBytes, err
|
||||
}
|
||||
}
|
||||
return totalBytes, nil
|
||||
}
|
||||
|
||||
type noOpWriter struct{}
|
||||
|
||||
func (noOpWriter) Write(b MultiBuffer) error {
|
||||
func (noOpWriter) WriteMultiBuffer(b MultiBuffer) error {
|
||||
b.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (noOpWriter) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
|
||||
b := New()
|
||||
defer b.Release()
|
||||
|
||||
totalBytes := int64(0)
|
||||
for {
|
||||
err := b.Reset(ReadFrom(reader))
|
||||
totalBytes += int64(b.Len())
|
||||
if err != nil {
|
||||
if errors.Cause(err) == io.EOF {
|
||||
return totalBytes, nil
|
||||
}
|
||||
return totalBytes, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// Discard is a Writer that swallows all contents written in.
|
||||
Discard Writer = noOpWriter{}
|
||||
|
||||
// DiscardBytes is an io.Writer that swallows all contents written in.
|
||||
DiscardBytes io.Writer = noOpWriter{}
|
||||
)
|
||||
|
@ -9,37 +9,67 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
. "v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/testing/assert"
|
||||
"v2ray.com/core/transport/ray"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestWriter(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
lb := New()
|
||||
assert.Error(lb.AppendSupplier(ReadFrom(rand.Reader))).IsNil()
|
||||
assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
|
||||
|
||||
expectedBytes := append([]byte(nil), lb.Bytes()...)
|
||||
|
||||
writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
|
||||
|
||||
writer := NewWriter(NewBufferedWriter(writeBuffer))
|
||||
err := writer.Write(NewMultiBufferValue(lb))
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes())
|
||||
writer := NewBufferedWriter(NewWriter(writeBuffer))
|
||||
writer.SetBuffered(false)
|
||||
err := writer.WriteMultiBuffer(NewMultiBufferValue(lb))
|
||||
assert(err, IsNil)
|
||||
assert(writer.Flush(), IsNil)
|
||||
assert(expectedBytes, Equals, writeBuffer.Bytes())
|
||||
}
|
||||
|
||||
func TestBytesWriterReadFrom(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cache := ray.NewStream(context.Background())
|
||||
reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192))
|
||||
_, err := reader.WriteTo(ToBytesWriter(cache))
|
||||
assert.Error(err).IsNil()
|
||||
const size = 50000
|
||||
reader := bufio.NewReader(io.LimitReader(rand.Reader, size))
|
||||
writer := NewBufferedWriter(cache)
|
||||
writer.SetBuffered(false)
|
||||
nBytes, err := reader.WriteTo(writer)
|
||||
assert(nBytes, Equals, int64(size))
|
||||
assert(err, IsNil)
|
||||
|
||||
mb, err := cache.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(mb.Len()).Equals(8192)
|
||||
assert.Int(len(mb)).Equals(4)
|
||||
mb, err := cache.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(mb.Len(), Equals, size)
|
||||
}
|
||||
|
||||
func TestDiscardBytes(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
b := New()
|
||||
common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
|
||||
|
||||
nBytes, err := io.Copy(DiscardBytes, b)
|
||||
assert(nBytes, Equals, int64(Size))
|
||||
assert(err, IsNil)
|
||||
}
|
||||
|
||||
func TestDiscardBytesMultiBuffer(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
const size = 10240*1024 + 1
|
||||
buffer := bytes.NewBuffer(make([]byte, 0, size))
|
||||
common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
|
||||
|
||||
r := NewReader(buffer)
|
||||
nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r))
|
||||
assert(nBytes, Equals, int64(size))
|
||||
assert(err, IsNil)
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ func Must(err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Must2 panics if the second parameter is not nil.
|
||||
func Must2(v interface{}, err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -29,6 +29,26 @@ func (v StaticBytesGenerator) Next() []byte {
|
||||
return v.Content
|
||||
}
|
||||
|
||||
type IncreasingAEADNonceGenerator struct {
|
||||
nonce []byte
|
||||
}
|
||||
|
||||
func NewIncreasingAEADNonceGenerator() *IncreasingAEADNonceGenerator {
|
||||
return &IncreasingAEADNonceGenerator{
|
||||
nonce: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *IncreasingAEADNonceGenerator) Next() []byte {
|
||||
for i := range g.nonce {
|
||||
g.nonce[i]++
|
||||
if g.nonce[i] != 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return g.nonce
|
||||
}
|
||||
|
||||
type Authenticator interface {
|
||||
NonceSize() int
|
||||
Overhead() int
|
||||
@ -48,7 +68,10 @@ func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
|
||||
return nil, newError("invalid AEAD nonce size: ", len(iv))
|
||||
}
|
||||
|
||||
additionalData := v.AdditionalDataGenerator.Next()
|
||||
var additionalData []byte
|
||||
if v.AdditionalDataGenerator != nil {
|
||||
additionalData = v.AdditionalDataGenerator.Next()
|
||||
}
|
||||
return v.AEAD.Open(dst, iv, cipherText, additionalData)
|
||||
}
|
||||
|
||||
@ -58,7 +81,10 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
|
||||
return nil, newError("invalid AEAD nonce size: ", len(iv))
|
||||
}
|
||||
|
||||
additionalData := v.AdditionalDataGenerator.Next()
|
||||
var additionalData []byte
|
||||
if v.AdditionalDataGenerator != nil {
|
||||
additionalData = v.AdditionalDataGenerator.Next()
|
||||
}
|
||||
return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
|
||||
}
|
||||
|
||||
@ -93,7 +119,12 @@ func (r *AuthenticationReader) readSize() error {
|
||||
|
||||
sizeBytes := r.sizeParser.SizeBytes()
|
||||
if r.buffer.Len() < sizeBytes {
|
||||
r.buffer.Reset(buf.ReadFrom(r.buffer))
|
||||
if r.buffer.IsEmpty() {
|
||||
r.buffer.Clear()
|
||||
} else {
|
||||
common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
|
||||
}
|
||||
|
||||
delta := sizeBytes - r.buffer.Len()
|
||||
if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
|
||||
return err
|
||||
@ -146,18 +177,18 @@ func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
|
||||
func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
b, err := r.readChunk(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mb := buf.NewMultiBuffer()
|
||||
var mb buf.MultiBuffer
|
||||
if r.transferType == protocol.TransferTypeStream {
|
||||
mb.Write(b)
|
||||
} else {
|
||||
var bb *buf.Buffer
|
||||
if len(b) < buf.Size {
|
||||
if len(b) <= buf.Size {
|
||||
bb = buf.New()
|
||||
} else {
|
||||
bb = buf.NewLocal(len(b))
|
||||
@ -175,7 +206,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
|
||||
mb.Write(b)
|
||||
} else {
|
||||
var bb *buf.Buffer
|
||||
if len(b) < buf.Size {
|
||||
if len(b) <= buf.Size {
|
||||
bb = buf.New()
|
||||
} else {
|
||||
bb = buf.NewLocal(len(b))
|
||||
@ -190,79 +221,92 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
|
||||
|
||||
type AuthenticationWriter struct {
|
||||
auth Authenticator
|
||||
buffer []byte
|
||||
payload []byte
|
||||
writer *buf.BufferedWriter
|
||||
writer buf.Writer
|
||||
sizeParser ChunkSizeEncoder
|
||||
transferType protocol.TransferType
|
||||
}
|
||||
|
||||
func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
|
||||
const payloadSize = 1024
|
||||
return &AuthenticationWriter{
|
||||
auth: auth,
|
||||
buffer: make([]byte, payloadSize+sizeParser.SizeBytes()+auth.Overhead()),
|
||||
payload: make([]byte, payloadSize),
|
||||
writer: buf.NewBufferedWriterSize(writer, readerBufferSize),
|
||||
writer: buf.NewWriter(writer),
|
||||
sizeParser: sizeParser,
|
||||
transferType: transferType,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *AuthenticationWriter) append(b []byte) error {
|
||||
encryptedSize := len(b) + w.auth.Overhead()
|
||||
buffer := w.sizeParser.Encode(uint16(encryptedSize), w.buffer[:0])
|
||||
func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
|
||||
encryptedSize := b.Len() + w.auth.Overhead()
|
||||
|
||||
buffer, err := w.auth.Seal(buffer, b)
|
||||
if err != nil {
|
||||
return err
|
||||
eb := buf.New()
|
||||
common.Must(eb.Reset(func(bb []byte) (int, error) {
|
||||
w.sizeParser.Encode(uint16(encryptedSize), bb[:0])
|
||||
return w.sizeParser.SizeBytes(), nil
|
||||
}))
|
||||
if err := eb.AppendSupplier(func(bb []byte) (int, error) {
|
||||
_, err := w.auth.Seal(bb[:0], b.Bytes())
|
||||
return encryptedSize, err
|
||||
}); err != nil {
|
||||
eb.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := w.writer.Write(buffer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return eb, nil
|
||||
}
|
||||
|
||||
func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
payloadSize := buf.Size - w.auth.Overhead() - w.sizeParser.SizeBytes()
|
||||
mb2Write := buf.NewMultiBufferCap(len(mb) + 10)
|
||||
|
||||
for {
|
||||
n, _ := mb.Read(w.payload)
|
||||
if err := w.append(w.payload[:n]); err != nil {
|
||||
b := buf.New()
|
||||
common.Must(b.Reset(func(bb []byte) (int, error) {
|
||||
return mb.Read(bb[:payloadSize])
|
||||
}))
|
||||
eb, err := w.seal(b)
|
||||
b.Release()
|
||||
|
||||
if err != nil {
|
||||
mb2Write.Release()
|
||||
return err
|
||||
}
|
||||
mb2Write.Append(eb)
|
||||
if mb.IsEmpty() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return w.writer.Flush()
|
||||
return w.writer.WriteMultiBuffer(mb2Write)
|
||||
}
|
||||
|
||||
func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
mb2Write := buf.NewMultiBufferCap(len(mb) * 2)
|
||||
|
||||
for {
|
||||
b := mb.SplitFirst()
|
||||
if b == nil {
|
||||
b = buf.New()
|
||||
}
|
||||
if err := w.append(b.Bytes()); err != nil {
|
||||
b.Release()
|
||||
eb, err := w.seal(b)
|
||||
b.Release()
|
||||
if err != nil {
|
||||
mb2Write.Release()
|
||||
return err
|
||||
}
|
||||
b.Release()
|
||||
mb2Write.Append(eb)
|
||||
if mb.IsEmpty() {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return w.writer.Flush()
|
||||
return w.writer.WriteMultiBuffer(mb2Write)
|
||||
}
|
||||
|
||||
func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
|
||||
func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
if w.transferType == protocol.TransferTypeStream {
|
||||
return w.writeStream(mb)
|
||||
}
|
||||
|
@ -10,19 +10,19 @@ import (
|
||||
"v2ray.com/core/common/buf"
|
||||
. "v2ray.com/core/common/crypto"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestAuthenticationReaderWriter(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
key := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
block, err := aes.NewCipher(key)
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
|
||||
rawPayload := make([]byte, 8192*10)
|
||||
rand.Read(rawPayload)
|
||||
@ -42,10 +42,10 @@ func TestAuthenticationReaderWriter(t *testing.T) {
|
||||
AdditionalDataGenerator: &NoOpBytesGenerator{},
|
||||
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
|
||||
|
||||
assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
|
||||
assert.Int(cache.Len()).Equals(83360)
|
||||
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
|
||||
assert.Error(err).IsNil()
|
||||
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(payload)), IsNil)
|
||||
assert(cache.Len(), Equals, 82658)
|
||||
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
|
||||
assert(err, IsNil)
|
||||
|
||||
reader := NewAuthenticationReader(&AEADAuthenticator{
|
||||
AEAD: aead,
|
||||
@ -55,33 +55,33 @@ func TestAuthenticationReaderWriter(t *testing.T) {
|
||||
AdditionalDataGenerator: &NoOpBytesGenerator{},
|
||||
}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
|
||||
|
||||
mb := buf.NewMultiBuffer()
|
||||
var mb buf.MultiBuffer
|
||||
|
||||
for mb.Len() < len(rawPayload) {
|
||||
mb2, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
mb2, err := reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
|
||||
mb.AppendMulti(mb2)
|
||||
}
|
||||
|
||||
mbContent := make([]byte, 8192*10)
|
||||
mb.Read(mbContent)
|
||||
assert.Bytes(mbContent).Equals(rawPayload)
|
||||
assert(mbContent, Equals, rawPayload)
|
||||
|
||||
_, err = reader.Read()
|
||||
assert.Error(err).Equals(io.EOF)
|
||||
_, err = reader.ReadMultiBuffer()
|
||||
assert(err, Equals, io.EOF)
|
||||
}
|
||||
|
||||
func TestAuthenticationReaderWriterPacket(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
key := make([]byte, 16)
|
||||
rand.Read(key)
|
||||
block, err := aes.NewCipher(key)
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
|
||||
cache := buf.NewLocal(1024)
|
||||
iv := make([]byte, 12)
|
||||
@ -95,7 +95,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
|
||||
AdditionalDataGenerator: &NoOpBytesGenerator{},
|
||||
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
|
||||
|
||||
payload := buf.NewMultiBuffer()
|
||||
var payload buf.MultiBuffer
|
||||
pb1 := buf.New()
|
||||
pb1.Append([]byte("abcd"))
|
||||
payload.Append(pb1)
|
||||
@ -104,10 +104,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
|
||||
pb2.Append([]byte("efgh"))
|
||||
payload.Append(pb2)
|
||||
|
||||
assert.Error(writer.Write(payload)).IsNil()
|
||||
assert.Int(cache.Len()).GreaterThan(0)
|
||||
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
|
||||
assert.Error(err).IsNil()
|
||||
assert(writer.WriteMultiBuffer(payload), IsNil)
|
||||
assert(cache.Len(), GreaterThan, 0)
|
||||
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
|
||||
assert(err, IsNil)
|
||||
|
||||
reader := NewAuthenticationReader(&AEADAuthenticator{
|
||||
AEAD: aead,
|
||||
@ -117,15 +117,15 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
|
||||
AdditionalDataGenerator: &NoOpBytesGenerator{},
|
||||
}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
|
||||
|
||||
mb, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
|
||||
b1 := mb.SplitFirst()
|
||||
assert.String(b1.String()).Equals("abcd")
|
||||
assert(b1.String(), Equals, "abcd")
|
||||
b2 := mb.SplitFirst()
|
||||
assert.String(b2.String()).Equals("efgh")
|
||||
assert.Bool(mb.IsEmpty()).IsTrue()
|
||||
assert(b2.String(), Equals, "efgh")
|
||||
assert(mb.IsEmpty(), IsTrue)
|
||||
|
||||
_, err = reader.Read()
|
||||
assert.Error(err).Equals(io.EOF)
|
||||
_, err = reader.ReadMultiBuffer()
|
||||
assert(err, Equals, io.EOF)
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
"v2ray.com/core/common"
|
||||
. "v2ray.com/core/common/crypto"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func mustDecodeHex(s string) []byte {
|
||||
@ -17,7 +17,7 @@ func mustDecodeHex(s string) []byte {
|
||||
}
|
||||
|
||||
func TestChaCha20Stream(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
var cases = []struct {
|
||||
key []byte
|
||||
@ -51,12 +51,12 @@ func TestChaCha20Stream(t *testing.T) {
|
||||
input := make([]byte, len(c.output))
|
||||
actualOutout := make([]byte, len(c.output))
|
||||
s.XORKeyStream(actualOutout, input)
|
||||
assert.Bytes(c.output).Equals(actualOutout)
|
||||
assert(c.output, Equals, actualOutout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChaCha20Decoding(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
key := make([]byte, 32)
|
||||
rand.Read(key)
|
||||
@ -72,5 +72,5 @@ func TestChaCha20Decoding(t *testing.T) {
|
||||
|
||||
stream2 := NewChaCha20Stream(key, iv)
|
||||
stream2.XORKeyStream(x, x)
|
||||
assert.Bytes(x).Equals(payload)
|
||||
assert(x, Equals, payload)
|
||||
}
|
||||
|
@ -3,15 +3,18 @@ package crypto
|
||||
import (
|
||||
"io"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/serial"
|
||||
)
|
||||
|
||||
// ChunkSizeDecoder is an utility class to decode size value from bytes.
|
||||
type ChunkSizeDecoder interface {
|
||||
SizeBytes() int
|
||||
Decode([]byte) (uint16, error)
|
||||
}
|
||||
|
||||
// ChunkSizeEncoder is an utility class to encode size value into bytes.
|
||||
type ChunkSizeEncoder interface {
|
||||
SizeBytes() int
|
||||
Encode(uint16, []byte) []byte
|
||||
@ -31,50 +34,53 @@ func (PlainChunkSizeParser) Decode(b []byte) (uint16, error) {
|
||||
return serial.BytesToUint16(b), nil
|
||||
}
|
||||
|
||||
type AEADChunkSizeParser struct {
|
||||
Auth *AEADAuthenticator
|
||||
}
|
||||
|
||||
func (p *AEADChunkSizeParser) SizeBytes() int {
|
||||
return 2 + p.Auth.Overhead()
|
||||
}
|
||||
|
||||
func (p *AEADChunkSizeParser) Encode(size uint16, b []byte) []byte {
|
||||
b = serial.Uint16ToBytes(size-uint16(p.Auth.Overhead()), b)
|
||||
b, err := p.Auth.Seal(b[:0], b)
|
||||
common.Must(err)
|
||||
return b
|
||||
}
|
||||
|
||||
func (p *AEADChunkSizeParser) Decode(b []byte) (uint16, error) {
|
||||
b, err := p.Auth.Open(b[:0], b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return serial.BytesToUint16(b) + uint16(p.Auth.Overhead()), nil
|
||||
}
|
||||
|
||||
type ChunkStreamReader struct {
|
||||
sizeDecoder ChunkSizeDecoder
|
||||
reader buf.Reader
|
||||
reader *buf.BufferedReader
|
||||
|
||||
buffer []byte
|
||||
leftOver buf.MultiBuffer
|
||||
leftOverSize int
|
||||
}
|
||||
|
||||
func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader {
|
||||
return &ChunkStreamReader{
|
||||
sizeDecoder: sizeDecoder,
|
||||
reader: buf.NewReader(reader),
|
||||
reader: buf.NewBufferedReader(buf.NewReader(reader)),
|
||||
buffer: make([]byte, sizeDecoder.SizeBytes()),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ChunkStreamReader) readAtLeast(size int) error {
|
||||
mb := r.leftOver
|
||||
r.leftOver = nil
|
||||
for mb.Len() < size {
|
||||
extra, err := r.reader.Read()
|
||||
if err != nil {
|
||||
mb.Release()
|
||||
return err
|
||||
}
|
||||
mb.AppendMulti(extra)
|
||||
}
|
||||
r.leftOver = mb
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ChunkStreamReader) readSize() (uint16, error) {
|
||||
if r.sizeDecoder.SizeBytes() > r.leftOver.Len() {
|
||||
if err := r.readAtLeast(r.sizeDecoder.SizeBytes() - r.leftOver.Len()); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if _, err := io.ReadFull(r.reader, r.buffer); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r.leftOver.Read(r.buffer)
|
||||
return r.sizeDecoder.Decode(r.buffer)
|
||||
}
|
||||
|
||||
func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) {
|
||||
func (r *ChunkStreamReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
||||
size := r.leftOverSize
|
||||
if size == 0 {
|
||||
nextSize, err := r.readSize()
|
||||
@ -86,29 +92,14 @@ func (r *ChunkStreamReader) Read() (buf.MultiBuffer, error) {
|
||||
}
|
||||
size = int(nextSize)
|
||||
}
|
||||
r.leftOverSize = size
|
||||
|
||||
if r.leftOver.IsEmpty() {
|
||||
if err := r.readAtLeast(1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if size >= r.leftOver.Len() {
|
||||
mb := r.leftOver
|
||||
r.leftOverSize = size - r.leftOver.Len()
|
||||
r.leftOver = nil
|
||||
mb, err := r.reader.ReadAtMost(size)
|
||||
if !mb.IsEmpty() {
|
||||
r.leftOverSize -= mb.Len()
|
||||
return mb, nil
|
||||
}
|
||||
|
||||
mb := r.leftOver.SliceBySize(size)
|
||||
if mb.Len() != size {
|
||||
b := buf.New()
|
||||
b.AppendSupplier(buf.ReadFullFrom(&r.leftOver, size-mb.Len()))
|
||||
mb.Append(b)
|
||||
}
|
||||
|
||||
r.leftOverSize = 0
|
||||
return mb, nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type ChunkStreamWriter struct {
|
||||
@ -123,18 +114,19 @@ func NewChunkStreamWriter(sizeEncoder ChunkSizeEncoder, writer io.Writer) *Chunk
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
|
||||
mb2Write := buf.NewMultiBuffer()
|
||||
func (w *ChunkStreamWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
const sliceSize = 8192
|
||||
mbLen := mb.Len()
|
||||
mb2Write := buf.NewMultiBufferCap(mbLen/buf.Size + mbLen/sliceSize + 2)
|
||||
|
||||
for {
|
||||
slice := mb.SliceBySize(sliceSize)
|
||||
|
||||
b := buf.New()
|
||||
b.AppendSupplier(func(buffer []byte) (int, error) {
|
||||
common.Must(b.Reset(func(buffer []byte) (int, error) {
|
||||
w.sizeEncoder.Encode(uint16(slice.Len()), buffer[:0])
|
||||
return w.sizeEncoder.SizeBytes(), nil
|
||||
})
|
||||
}))
|
||||
mb2Write.Append(b)
|
||||
mb2Write.AppendMulti(slice)
|
||||
|
||||
@ -143,5 +135,5 @@ func (w *ChunkStreamWriter) Write(mb buf.MultiBuffer) error {
|
||||
}
|
||||
}
|
||||
|
||||
return w.writer.Write(mb2Write)
|
||||
return w.writer.WriteMultiBuffer(mb2Write)
|
||||
}
|
||||
|
@ -6,11 +6,11 @@ import (
|
||||
|
||||
"v2ray.com/core/common/buf"
|
||||
. "v2ray.com/core/common/crypto"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestChunkStreamIO(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cache := buf.NewLocal(8192)
|
||||
|
||||
@ -19,26 +19,26 @@ func TestChunkStreamIO(t *testing.T) {
|
||||
|
||||
b := buf.New()
|
||||
b.AppendBytes('a', 'b', 'c', 'd')
|
||||
assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil()
|
||||
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
|
||||
|
||||
b = buf.New()
|
||||
b.AppendBytes('e', 'f', 'g')
|
||||
assert.Error(writer.Write(buf.NewMultiBufferValue(b))).IsNil()
|
||||
assert(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)), IsNil)
|
||||
|
||||
assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
|
||||
assert(writer.WriteMultiBuffer(buf.MultiBuffer{}), IsNil)
|
||||
|
||||
assert.Int(cache.Len()).Equals(13)
|
||||
assert(cache.Len(), Equals, 13)
|
||||
|
||||
mb, err := reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(mb.Len()).Equals(4)
|
||||
assert.Bytes(mb[0].Bytes()).Equals([]byte("abcd"))
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(mb.Len(), Equals, 4)
|
||||
assert(mb[0].Bytes(), Equals, []byte("abcd"))
|
||||
|
||||
mb, err = reader.Read()
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int(mb.Len()).Equals(3)
|
||||
assert.Bytes(mb[0].Bytes()).Equals([]byte("efg"))
|
||||
mb, err = reader.ReadMultiBuffer()
|
||||
assert(err, IsNil)
|
||||
assert(mb.Len(), Equals, 3)
|
||||
assert(mb[0].Bytes(), Equals, []byte("efg"))
|
||||
|
||||
_, err = reader.Read()
|
||||
assert.Error(err).Equals(io.EOF)
|
||||
_, err = reader.ReadMultiBuffer()
|
||||
assert(err, Equals, io.EOF)
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func (r *CryptionReader) Read(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
var (
|
||||
_ buf.MultiBufferWriter = (*CryptionWriter)(nil)
|
||||
_ buf.Writer = (*CryptionWriter)(nil)
|
||||
)
|
||||
|
||||
type CryptionWriter struct {
|
||||
@ -51,6 +51,8 @@ func (w *CryptionWriter) Write(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
func (w *CryptionWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
defer mb.Release()
|
||||
|
||||
bs := mb.ToNetBuffers()
|
||||
for _, b := range bs {
|
||||
w.stream.XORKeyStream(b, b)
|
||||
|
@ -5,29 +5,29 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/errors"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
err := New("TestError")
|
||||
assert.Bool(GetSeverity(err) == SeverityInfo).IsTrue()
|
||||
assert(GetSeverity(err), Equals, SeverityInfo)
|
||||
|
||||
err = New("TestError2").Base(io.EOF)
|
||||
assert.Bool(GetSeverity(err) == SeverityInfo).IsTrue()
|
||||
assert(GetSeverity(err), Equals, SeverityInfo)
|
||||
|
||||
err = New("TestError3").Base(io.EOF).AtWarning()
|
||||
assert.Bool(GetSeverity(err) == SeverityWarning).IsTrue()
|
||||
assert(GetSeverity(err), Equals, SeverityWarning)
|
||||
|
||||
err = New("TestError4").Base(io.EOF).AtWarning()
|
||||
err = New("TestError5").Base(err)
|
||||
assert.Bool(GetSeverity(err) == SeverityWarning).IsTrue()
|
||||
assert.String(err.Error()).Contains("EOF")
|
||||
assert(GetSeverity(err), Equals, SeverityWarning)
|
||||
assert(err.Error(), HasSubstring, "EOF")
|
||||
}
|
||||
|
||||
func TestErrorMessage(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
data := []struct {
|
||||
err error
|
||||
@ -44,6 +44,6 @@ func TestErrorMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, d := range data {
|
||||
assert.String(d.err.Error()).Equals(d.msg)
|
||||
assert(d.err.Error(), Equals, d.msg)
|
||||
}
|
||||
}
|
||||
|
46
common/event/event.go
Normal file
46
common/event/event.go
Normal file
@ -0,0 +1,46 @@
|
||||
package event
|
||||
|
||||
import "sync"
|
||||
|
||||
type Event uint16
|
||||
|
||||
type Handler func(data interface{}) error
|
||||
|
||||
type Registry interface {
|
||||
On(Event, Handler)
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
sync.RWMutex
|
||||
events map[Event][]Handler
|
||||
}
|
||||
|
||||
func (l *Listener) On(e Event, h Handler) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if l.events == nil {
|
||||
l.events = make(map[Event][]Handler)
|
||||
}
|
||||
|
||||
handlers := l.events[e]
|
||||
handlers = append(handlers, h)
|
||||
l.events[e] = handlers
|
||||
}
|
||||
|
||||
func (l *Listener) Fire(e Event, data interface{}) error {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
if l.events == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, h := range l.events[e] {
|
||||
if err := h(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -73,6 +73,12 @@ type Address interface {
|
||||
// ParseAddress parses a string into an Address. The return value will be an IPAddress when
|
||||
// the string is in the form of IPv4 or IPv6 address, or a DomainAddress otherwise.
|
||||
func ParseAddress(addr string) Address {
|
||||
// Handle IPv6 address in form as "[2001:4860:0:2001::68]"
|
||||
lenAddr := len(addr)
|
||||
if lenAddr > 0 && addr[0] == '[' && addr[lenAddr-1] == ']' {
|
||||
addr = addr[1 : lenAddr-1]
|
||||
}
|
||||
|
||||
ip := net.ParseIP(addr)
|
||||
if ip != nil {
|
||||
return IPAddress(ip)
|
||||
|
@ -5,24 +5,25 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/net"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/core/common/net/testing"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestIPv4Address(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
ip := []byte{byte(1), byte(2), byte(3), byte(4)}
|
||||
addr := IPAddress(ip)
|
||||
|
||||
assert.Address(addr).IsIPv4()
|
||||
assert.Address(addr).IsNotIPv6()
|
||||
assert.Address(addr).IsNotDomain()
|
||||
assert.Bytes(addr.IP()).Equals(ip)
|
||||
assert.Address(addr).EqualsString("1.2.3.4")
|
||||
assert(addr, IsIPv4)
|
||||
assert(addr, Not(IsIPv6))
|
||||
assert(addr, Not(IsDomain))
|
||||
assert([]byte(addr.IP()), Equals, ip)
|
||||
assert(addr.String(), Equals, "1.2.3.4")
|
||||
}
|
||||
|
||||
func TestIPv6Address(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
ip := []byte{
|
||||
byte(1), byte(2), byte(3), byte(4),
|
||||
@ -32,15 +33,15 @@ func TestIPv6Address(t *testing.T) {
|
||||
}
|
||||
addr := IPAddress(ip)
|
||||
|
||||
assert.Address(addr).IsIPv6()
|
||||
assert.Address(addr).IsNotIPv4()
|
||||
assert.Address(addr).IsNotDomain()
|
||||
assert.IP(addr.IP()).Equals(net.IP(ip))
|
||||
assert.Address(addr).EqualsString("[102:304:102:304:102:304:102:304]")
|
||||
assert(addr, IsIPv6)
|
||||
assert(addr, Not(IsIPv4))
|
||||
assert(addr, Not(IsDomain))
|
||||
assert(addr.IP(), Equals, net.IP(ip))
|
||||
assert(addr.String(), Equals, "[102:304:102:304:102:304:102:304]")
|
||||
}
|
||||
|
||||
func TestIPv4Asv6(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
ip := []byte{
|
||||
byte(0), byte(0), byte(0), byte(0),
|
||||
byte(0), byte(0), byte(0), byte(0),
|
||||
@ -48,27 +49,55 @@ func TestIPv4Asv6(t *testing.T) {
|
||||
byte(1), byte(2), byte(3), byte(4),
|
||||
}
|
||||
addr := IPAddress(ip)
|
||||
assert.Address(addr).EqualsString("1.2.3.4")
|
||||
assert(addr.String(), Equals, "1.2.3.4")
|
||||
}
|
||||
|
||||
func TestDomainAddress(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
domain := "v2ray.com"
|
||||
addr := DomainAddress(domain)
|
||||
|
||||
assert.Address(addr).IsDomain()
|
||||
assert.Address(addr).IsNotIPv6()
|
||||
assert.Address(addr).IsNotIPv4()
|
||||
assert.String(addr.Domain()).Equals(domain)
|
||||
assert.Address(addr).EqualsString("v2ray.com")
|
||||
assert(addr, IsDomain)
|
||||
assert(addr, Not(IsIPv6))
|
||||
assert(addr, Not(IsIPv4))
|
||||
assert(addr.Domain(), Equals, domain)
|
||||
assert(addr.String(), Equals, "v2ray.com")
|
||||
}
|
||||
|
||||
func TestNetIPv4Address(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
ip := net.IPv4(1, 2, 3, 4)
|
||||
addr := IPAddress(ip)
|
||||
assert.Address(addr).IsIPv4()
|
||||
assert.Address(addr).EqualsString("1.2.3.4")
|
||||
assert(addr, IsIPv4)
|
||||
assert(addr.String(), Equals, "1.2.3.4")
|
||||
}
|
||||
|
||||
func TestParseIPv6Address(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
ip := ParseAddress("[2001:4860:0:2001::68]")
|
||||
assert(ip, IsIPv6)
|
||||
assert(ip.String(), Equals, "[2001:4860:0:2001::68]")
|
||||
|
||||
ip = ParseAddress("[::ffff:123.151.71.143]")
|
||||
assert(ip, IsIPv4)
|
||||
assert(ip.String(), Equals, "123.151.71.143")
|
||||
}
|
||||
|
||||
func TestInvalidAddressConvertion(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
assert(func() { ParseAddress("8.8.8.8").Domain() }, Panics)
|
||||
assert(func() { ParseAddress("2001:4860:0:2001::68").Domain() }, Panics)
|
||||
assert(func() { ParseAddress("v2ray.com").IP() }, Panics)
|
||||
}
|
||||
|
||||
func TestIPOrDomain(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
assert(NewIPOrDomain(ParseAddress("v2ray.com")).AsAddress(), Equals, ParseAddress("v2ray.com"))
|
||||
assert(NewIPOrDomain(ParseAddress("8.8.8.8")).AsAddress(), Equals, ParseAddress("8.8.8.8"))
|
||||
assert(NewIPOrDomain(ParseAddress("2001:4860:0:2001::68")).AsAddress(), Equals, ParseAddress("2001:4860:0:2001::68"))
|
||||
}
|
||||
|
@ -41,14 +41,17 @@ func UDPDestination(address Address, port Port) Destination {
|
||||
}
|
||||
}
|
||||
|
||||
// NetAddr returns the network address in this Destination in string form.
|
||||
func (d Destination) NetAddr() string {
|
||||
return d.Address.String() + ":" + d.Port.String()
|
||||
}
|
||||
|
||||
// String returns the strings form of this Destination.
|
||||
func (d Destination) String() string {
|
||||
return d.Network.URLPrefix() + ":" + d.NetAddr()
|
||||
}
|
||||
|
||||
// IsValid returns true if this Destination is valid.
|
||||
func (d Destination) IsValid() bool {
|
||||
return d.Network != Network_Unknown
|
||||
}
|
||||
|
@ -4,23 +4,24 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/net"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/core/common/net/testing"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestTCPDestination(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
dest := TCPDestination(IPAddress([]byte{1, 2, 3, 4}), 80)
|
||||
assert.Destination(dest).IsTCP()
|
||||
assert.Destination(dest).IsNotUDP()
|
||||
assert.Destination(dest).EqualsString("tcp:1.2.3.4:80")
|
||||
assert(dest, IsTCP)
|
||||
assert(dest, Not(IsUDP))
|
||||
assert(dest.String(), Equals, "tcp:1.2.3.4:80")
|
||||
}
|
||||
|
||||
func TestUDPDestination(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
dest := UDPDestination(IPAddress([]byte{0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88}), 53)
|
||||
assert.Destination(dest).IsNotTCP()
|
||||
assert.Destination(dest).IsUDP()
|
||||
assert.Destination(dest).EqualsString("udp:[2001:4860:4860::8888]:53")
|
||||
assert(dest, Not(IsTCP))
|
||||
assert(dest, IsUDP)
|
||||
assert(dest.String(), Equals, "udp:[2001:4860:4860::8888]:53")
|
||||
}
|
||||
|
@ -44,6 +44,7 @@ func (n *IPNetTable) Add(ipNet *net.IPNet) {
|
||||
|
||||
func (n *IPNetTable) AddIP(ip []byte, mask byte) {
|
||||
k := ipToUint32(ip)
|
||||
k = (k >> (32 - mask)) << (32 - mask) // normalize ip
|
||||
existing, found := n.cache[k]
|
||||
if !found || existing > mask {
|
||||
n.cache[k] = mask
|
||||
|
@ -2,11 +2,19 @@ package net_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
"v2ray.com/core/app/router"
|
||||
"v2ray.com/core/common/platform"
|
||||
|
||||
"v2ray.com/ext/sysio"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
. "v2ray.com/core/common/net"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func parseCIDR(str string) *net.IPNet {
|
||||
@ -16,7 +24,7 @@ func parseCIDR(str string) *net.IPNet {
|
||||
}
|
||||
|
||||
func TestIPNet(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
ipNet := NewIPNetTable()
|
||||
ipNet.Add(parseCIDR(("0.0.0.0/8")))
|
||||
@ -32,12 +40,94 @@ func TestIPNet(t *testing.T) {
|
||||
ipNet.Add(parseCIDR(("198.51.100.0/24")))
|
||||
ipNet.Add(parseCIDR(("203.0.113.0/24")))
|
||||
ipNet.Add(parseCIDR(("8.8.8.8/32")))
|
||||
assert.Bool(ipNet.Contains(ParseIP("192.168.1.1"))).IsTrue()
|
||||
assert.Bool(ipNet.Contains(ParseIP("192.0.0.0"))).IsTrue()
|
||||
assert.Bool(ipNet.Contains(ParseIP("192.0.1.0"))).IsFalse()
|
||||
assert.Bool(ipNet.Contains(ParseIP("0.1.0.0"))).IsTrue()
|
||||
assert.Bool(ipNet.Contains(ParseIP("1.0.0.1"))).IsFalse()
|
||||
assert.Bool(ipNet.Contains(ParseIP("8.8.8.7"))).IsFalse()
|
||||
assert.Bool(ipNet.Contains(ParseIP("8.8.8.8"))).IsTrue()
|
||||
assert.Bool(ipNet.Contains(ParseIP("2001:cdba::3257:9652"))).IsFalse()
|
||||
ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
|
||||
assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
|
||||
assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
|
||||
assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
|
||||
}
|
||||
|
||||
func TestGeoIPCN(t *testing.T) {
|
||||
assert := With(t)
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
|
||||
|
||||
ips, err := loadGeoIP("CN")
|
||||
common.Must(err)
|
||||
|
||||
ipNet := NewIPNetTable()
|
||||
for _, ip := range ips {
|
||||
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
|
||||
}
|
||||
|
||||
assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
|
||||
}
|
||||
|
||||
func loadGeoIP(country string) ([]*router.CIDR, error) {
|
||||
geoipBytes, err := sysio.ReadAsset("geoip.dat")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var geoipList router.GeoIPList
|
||||
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, geoip := range geoipList.Entry {
|
||||
if geoip.CountryCode == country {
|
||||
return geoip.Cidr, nil
|
||||
}
|
||||
}
|
||||
|
||||
panic("country not found: " + country)
|
||||
}
|
||||
|
||||
func BenchmarkIPNetQuery(b *testing.B) {
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
|
||||
|
||||
ips, err := loadGeoIP("CN")
|
||||
common.Must(err)
|
||||
|
||||
ipNet := NewIPNetTable()
|
||||
for _, ip := range ips {
|
||||
ipNet.AddIP(ip.Ip, byte(ip.Prefix))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
ipNet.Contains([]byte{8, 8, 8, 8})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCIDRQuery(b *testing.B) {
|
||||
common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
|
||||
|
||||
ips, err := loadGeoIP("CN")
|
||||
common.Must(err)
|
||||
|
||||
ipNet := make([]*net.IPNet, 0, 1024)
|
||||
for _, ip := range ips {
|
||||
if len(ip.Ip) != 4 {
|
||||
continue
|
||||
}
|
||||
ipNet = append(ipNet, &net.IPNet{
|
||||
IP: net.IP(ip.Ip),
|
||||
Mask: net.CIDRMask(int(ip.Prefix), 32),
|
||||
})
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, n := range ipNet {
|
||||
if n.Contains([]byte{8, 8, 8, 8}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Package net contains common network utilities.
|
||||
// Package net is a drop-in replacement to Golang's net package, with some more functionalities.
|
||||
package net
|
||||
|
||||
//go:generate go run $GOPATH/src/v2ray.com/core/tools/generrorgen/main.go -pkg net -path Net
|
||||
|
@ -4,15 +4,15 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/net"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestPortRangeContains(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
portRange := &PortRange{
|
||||
From: 53,
|
||||
To: 53,
|
||||
}
|
||||
assert.Bool(portRange.Contains(Port(53))).IsTrue()
|
||||
assert(portRange.Contains(Port(53)), IsTrue)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package net
|
||||
|
||||
import "net"
|
||||
|
||||
// DialTCP is an injectable function. Default to net.DialTCP
|
||||
var DialTCP = net.DialTCP
|
||||
var DialUDP = net.DialUDP
|
||||
var DialUnix = net.DialUnix
|
||||
@ -31,6 +32,7 @@ type UDPConn = net.UDPConn
|
||||
type UnixAddr = net.UnixAddr
|
||||
type UnixConn = net.UnixConn
|
||||
|
||||
// IP is an alias for net.IP.
|
||||
type IP = net.IP
|
||||
type IPMask = net.IPMask
|
||||
type IPNet = net.IPNet
|
||||
|
48
common/net/testing/assert.go
Normal file
48
common/net/testing/assert.go
Normal file
@ -0,0 +1,48 @@
|
||||
package testing
|
||||
|
||||
import (
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
var IsIPv4 = assert.CreateMatcher(func(a net.Address) bool {
|
||||
return a.Family().IsIPv4()
|
||||
}, "is IPv4")
|
||||
|
||||
var IsIPv6 = assert.CreateMatcher(func(a net.Address) bool {
|
||||
return a.Family().IsIPv6()
|
||||
}, "is IPv6")
|
||||
|
||||
var IsIP = assert.CreateMatcher(func(a net.Address) bool {
|
||||
return a.Family().IsIPv4() || a.Family().IsIPv6()
|
||||
}, "is IP")
|
||||
|
||||
var IsTCP = assert.CreateMatcher(func(a net.Destination) bool {
|
||||
return a.Network == net.Network_TCP
|
||||
}, "is TCP")
|
||||
|
||||
var IsUDP = assert.CreateMatcher(func(a net.Destination) bool {
|
||||
return a.Network == net.Network_UDP
|
||||
}, "is UDP")
|
||||
|
||||
var IsDomain = assert.CreateMatcher(func(a net.Address) bool {
|
||||
return a.Family().IsDomain()
|
||||
}, "is Domain")
|
||||
|
||||
func init() {
|
||||
assert.RegisterEqualsMatcher(func(a, b net.Address) bool {
|
||||
return a == b
|
||||
})
|
||||
|
||||
assert.RegisterEqualsMatcher(func(a, b net.Destination) bool {
|
||||
return a == b
|
||||
})
|
||||
|
||||
assert.RegisterEqualsMatcher(func(a, b net.Port) bool {
|
||||
return a == b
|
||||
})
|
||||
|
||||
assert.RegisterEqualsMatcher(func(a, b net.IP) bool {
|
||||
return a.Equal(b)
|
||||
})
|
||||
}
|
@ -4,6 +4,7 @@ package platform
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func ExpandEnv(s string) string {
|
||||
@ -13,3 +14,9 @@ func ExpandEnv(s string) string {
|
||||
func LineSeparator() string {
|
||||
return "\n"
|
||||
}
|
||||
|
||||
func GetToolLocation(file string) string {
|
||||
const name = "v2ray.location.tool"
|
||||
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
|
||||
return filepath.Join(toolPath, file)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package platform
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@ -11,7 +12,7 @@ type EnvFlag struct {
|
||||
AltName string
|
||||
}
|
||||
|
||||
func (f EnvFlag) GetValue(defaultValue string) string {
|
||||
func (f EnvFlag) GetValue(defaultValue func() string) string {
|
||||
if v, found := os.LookupEnv(f.Name); found {
|
||||
return v
|
||||
}
|
||||
@ -21,13 +22,16 @@ func (f EnvFlag) GetValue(defaultValue string) string {
|
||||
}
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
return defaultValue()
|
||||
}
|
||||
|
||||
func (f EnvFlag) GetValueAsInt(defaultValue int) int {
|
||||
const PlaceHolder = "xxxxxx"
|
||||
s := f.GetValue(PlaceHolder)
|
||||
if s == PlaceHolder {
|
||||
useDefaultValue := false
|
||||
s := f.GetValue(func() string {
|
||||
useDefaultValue = true
|
||||
return ""
|
||||
})
|
||||
if useDefaultValue {
|
||||
return defaultValue
|
||||
}
|
||||
v, err := strconv.ParseInt(s, 10, 32)
|
||||
@ -40,3 +44,29 @@ func (f EnvFlag) GetValueAsInt(defaultValue int) int {
|
||||
func NormalizeEnvName(name string) string {
|
||||
return strings.Replace(strings.ToUpper(strings.TrimSpace(name)), ".", "_", -1)
|
||||
}
|
||||
|
||||
func getExecutableDir() string {
|
||||
exec, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Dir(exec)
|
||||
}
|
||||
|
||||
func getExecuableSubDir(dir string) func() string {
|
||||
return func() string {
|
||||
return filepath.Join(getExecutableDir(), dir)
|
||||
}
|
||||
}
|
||||
|
||||
func GetAssetLocation(file string) string {
|
||||
const name = "v2ray.location.asset"
|
||||
assetPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
|
||||
return filepath.Join(assetPath, file)
|
||||
}
|
||||
|
||||
func GetPluginDirectory() string {
|
||||
const name = "v2ray.location.plugin"
|
||||
pluginDir := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecuableSubDir("plugins"))
|
||||
return pluginDir
|
||||
}
|
||||
|
@ -1,14 +1,16 @@
|
||||
package platform_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/platform"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestNormalizeEnvName(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cases := []struct {
|
||||
input string
|
||||
@ -28,14 +30,27 @@ func TestNormalizeEnvName(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, test := range cases {
|
||||
assert.String(NormalizeEnvName(test.input)).Equals(test.output)
|
||||
assert(NormalizeEnvName(test.input), Equals, test.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvFlag(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
assert.Int(EnvFlag{
|
||||
assert(EnvFlag{
|
||||
Name: "xxxxx.y",
|
||||
}.GetValueAsInt(10)).Equals(10)
|
||||
}.GetValueAsInt(10), Equals, 10)
|
||||
}
|
||||
|
||||
func TestGetAssetLocation(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
exec, err := os.Executable()
|
||||
assert(err, IsNil)
|
||||
|
||||
loc := GetAssetLocation("t")
|
||||
assert(filepath.Dir(loc), Equals, filepath.Dir(exec))
|
||||
|
||||
os.Setenv("v2ray.location.asset", "/v2ray")
|
||||
assert(GetAssetLocation("t"), Equals, "/v2ray/t")
|
||||
}
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
package platform
|
||||
|
||||
import "path/filepath"
|
||||
|
||||
func ExpandEnv(s string) string {
|
||||
// TODO
|
||||
return s
|
||||
@ -10,3 +12,9 @@ func ExpandEnv(s string) string {
|
||||
func LineSeparator() string {
|
||||
return "\r\n"
|
||||
}
|
||||
|
||||
func GetToolLocation(file string) string {
|
||||
const name = "v2ray.location.tool"
|
||||
toolPath := EnvFlag{Name: name, AltName: NormalizeEnvName(name)}.GetValue(getExecutableDir)
|
||||
return filepath.Join(toolPath, file+".exe")
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package protocol
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"v2ray.com/core/common/bitmask"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/uuid"
|
||||
)
|
||||
@ -24,35 +25,20 @@ func (c RequestCommand) TransferType() TransferType {
|
||||
return TransferTypePacket
|
||||
}
|
||||
|
||||
// RequestOption is the options of a request.
|
||||
type RequestOption byte
|
||||
|
||||
const (
|
||||
// RequestOptionChunkStream indicates request payload is chunked. Each chunk consists of length, authentication and payload.
|
||||
RequestOptionChunkStream = RequestOption(0x01)
|
||||
RequestOptionChunkStream bitmask.Byte = 0x01
|
||||
|
||||
// RequestOptionConnectionReuse indicates client side expects to reuse the connection.
|
||||
RequestOptionConnectionReuse = RequestOption(0x02)
|
||||
RequestOptionConnectionReuse bitmask.Byte = 0x02
|
||||
|
||||
RequestOptionChunkMasking = RequestOption(0x04)
|
||||
RequestOptionChunkMasking bitmask.Byte = 0x04
|
||||
)
|
||||
|
||||
func (v RequestOption) Has(option RequestOption) bool {
|
||||
return (v & option) == option
|
||||
}
|
||||
|
||||
func (v *RequestOption) Set(option RequestOption) {
|
||||
*v = (*v | option)
|
||||
}
|
||||
|
||||
func (v *RequestOption) Clear(option RequestOption) {
|
||||
*v = (*v & (^option))
|
||||
}
|
||||
|
||||
type Security byte
|
||||
|
||||
func (v Security) Is(t SecurityType) bool {
|
||||
return v == Security(t)
|
||||
func (s Security) Is(t SecurityType) bool {
|
||||
return s == Security(t)
|
||||
}
|
||||
|
||||
func NormSecurity(s Security) Security {
|
||||
@ -65,42 +51,28 @@ func NormSecurity(s Security) Security {
|
||||
type RequestHeader struct {
|
||||
Version byte
|
||||
Command RequestCommand
|
||||
Option RequestOption
|
||||
Option bitmask.Byte
|
||||
Security Security
|
||||
Port net.Port
|
||||
Address net.Address
|
||||
User *User
|
||||
}
|
||||
|
||||
func (v *RequestHeader) Destination() net.Destination {
|
||||
if v.Command == RequestCommandUDP {
|
||||
return net.UDPDestination(v.Address, v.Port)
|
||||
func (h *RequestHeader) Destination() net.Destination {
|
||||
if h.Command == RequestCommandUDP {
|
||||
return net.UDPDestination(h.Address, h.Port)
|
||||
}
|
||||
return net.TCPDestination(v.Address, v.Port)
|
||||
return net.TCPDestination(h.Address, h.Port)
|
||||
}
|
||||
|
||||
type ResponseOption byte
|
||||
|
||||
const (
|
||||
ResponseOptionConnectionReuse = ResponseOption(0x01)
|
||||
ResponseOptionConnectionReuse bitmask.Byte = 0x01
|
||||
)
|
||||
|
||||
func (v *ResponseOption) Set(option ResponseOption) {
|
||||
*v = (*v | option)
|
||||
}
|
||||
|
||||
func (v ResponseOption) Has(option ResponseOption) bool {
|
||||
return (v & option) == option
|
||||
}
|
||||
|
||||
func (v *ResponseOption) Clear(option ResponseOption) {
|
||||
*v = (*v & (^option))
|
||||
}
|
||||
|
||||
type ResponseCommand interface{}
|
||||
|
||||
type ResponseHeader struct {
|
||||
Option ResponseOption
|
||||
Option bitmask.Byte
|
||||
Command ResponseCommand
|
||||
}
|
||||
|
||||
@ -108,20 +80,21 @@ type CommandSwitchAccount struct {
|
||||
Host net.Address
|
||||
Port net.Port
|
||||
ID *uuid.UUID
|
||||
AlterIds uint16
|
||||
Level uint32
|
||||
AlterIds uint16
|
||||
ValidMin byte
|
||||
}
|
||||
|
||||
func (v *SecurityConfig) AsSecurity() Security {
|
||||
if v == nil {
|
||||
return Security(SecurityType_LEGACY)
|
||||
}
|
||||
if v.Type == SecurityType_AUTO {
|
||||
func (sc *SecurityConfig) AsSecurity() Security {
|
||||
if sc == nil || sc.Type == SecurityType_AUTO {
|
||||
if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" {
|
||||
return Security(SecurityType_AES128_GCM)
|
||||
}
|
||||
return Security(SecurityType_CHACHA20_POLY1305)
|
||||
}
|
||||
return NormSecurity(Security(v.Type))
|
||||
return NormSecurity(Security(sc.Type))
|
||||
}
|
||||
|
||||
func IsDomainTooLong(domain string) bool {
|
||||
return len(domain) > 256
|
||||
}
|
||||
|
@ -1,34 +0,0 @@
|
||||
package protocol_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/testing/assert"
|
||||
)
|
||||
|
||||
func TestRequestOptionSet(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
|
||||
var option RequestOption
|
||||
assert.Bool(option.Has(RequestOptionChunkStream)).IsFalse()
|
||||
|
||||
option.Set(RequestOptionChunkStream)
|
||||
assert.Bool(option.Has(RequestOptionChunkStream)).IsTrue()
|
||||
|
||||
option.Set(RequestOptionChunkMasking)
|
||||
assert.Bool(option.Has(RequestOptionChunkMasking)).IsTrue()
|
||||
assert.Bool(option.Has(RequestOptionChunkStream)).IsTrue()
|
||||
}
|
||||
|
||||
func TestRequestOptionClear(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
|
||||
var option RequestOption
|
||||
option.Set(RequestOptionChunkStream)
|
||||
option.Set(RequestOptionChunkMasking)
|
||||
|
||||
option.Clear(RequestOptionChunkStream)
|
||||
assert.Bool(option.Has(RequestOptionChunkStream)).IsFalse()
|
||||
assert.Bool(option.Has(RequestOptionChunkMasking)).IsTrue()
|
||||
}
|
@ -6,12 +6,12 @@ import (
|
||||
"v2ray.com/core/common/predicate"
|
||||
. "v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/common/uuid"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestCmdKey(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
id := NewID(uuid.New())
|
||||
assert.Bool(predicate.BytesAll(id.CmdKey(), 0)).IsFalse()
|
||||
assert(predicate.BytesAll(id.CmdKey(), 0), IsFalse)
|
||||
}
|
||||
|
@ -6,3 +6,11 @@ const (
|
||||
TransferTypeStream TransferType = 0
|
||||
TransferTypePacket TransferType = 1
|
||||
)
|
||||
|
||||
type AddressType byte
|
||||
|
||||
const (
|
||||
AddressTypeIPv4 AddressType = 1
|
||||
AddressTypeDomain AddressType = 2
|
||||
AddressTypeIPv6 AddressType = 3
|
||||
)
|
||||
|
@ -6,30 +6,30 @@ import (
|
||||
|
||||
"v2ray.com/core/common/net"
|
||||
. "v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestServerList(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
list := NewServerList()
|
||||
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
|
||||
assert.Uint32(list.Size()).Equals(1)
|
||||
assert(list.Size(), Equals, uint32(1))
|
||||
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(2)), BeforeTime(time.Now().Add(time.Second))))
|
||||
assert.Uint32(list.Size()).Equals(2)
|
||||
assert(list.Size(), Equals, uint32(2))
|
||||
|
||||
server := list.GetServer(1)
|
||||
assert.Port(server.Destination().Port).Equals(2)
|
||||
assert(server.Destination().Port, Equals, net.Port(2))
|
||||
time.Sleep(2 * time.Second)
|
||||
server = list.GetServer(1)
|
||||
assert.Pointer(server).IsNil()
|
||||
assert(server, IsNil)
|
||||
|
||||
server = list.GetServer(0)
|
||||
assert.Port(server.Destination().Port).Equals(1)
|
||||
assert(server.Destination().Port, Equals, net.Port(1))
|
||||
}
|
||||
|
||||
func TestServerPicker(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
list := NewServerList()
|
||||
list.AddServer(NewServerSpec(net.TCPDestination(net.LocalHostIP, net.Port(1)), AlwaysValid()))
|
||||
@ -38,17 +38,17 @@ func TestServerPicker(t *testing.T) {
|
||||
|
||||
picker := NewRoundRobinServerPicker(list)
|
||||
server := picker.PickServer()
|
||||
assert.Port(server.Destination().Port).Equals(1)
|
||||
assert(server.Destination().Port, Equals, net.Port(1))
|
||||
server = picker.PickServer()
|
||||
assert.Port(server.Destination().Port).Equals(2)
|
||||
assert(server.Destination().Port, Equals, net.Port(2))
|
||||
server = picker.PickServer()
|
||||
assert.Port(server.Destination().Port).Equals(3)
|
||||
assert(server.Destination().Port, Equals, net.Port(3))
|
||||
server = picker.PickServer()
|
||||
assert.Port(server.Destination().Port).Equals(1)
|
||||
assert(server.Destination().Port, Equals, net.Port(1))
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
server = picker.PickServer()
|
||||
assert.Port(server.Destination().Port).Equals(1)
|
||||
assert(server.Destination().Port, Equals, net.Port(1))
|
||||
server = picker.PickServer()
|
||||
assert.Port(server.Destination().Port).Equals(1)
|
||||
assert(server.Destination().Port, Equals, net.Port(1))
|
||||
}
|
||||
|
@ -5,27 +5,27 @@ import (
|
||||
"time"
|
||||
|
||||
. "v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestAlwaysValidStrategy(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
strategy := AlwaysValid()
|
||||
assert.Bool(strategy.IsValid()).IsTrue()
|
||||
assert(strategy.IsValid(), IsTrue)
|
||||
strategy.Invalidate()
|
||||
assert.Bool(strategy.IsValid()).IsTrue()
|
||||
assert(strategy.IsValid(), IsTrue)
|
||||
}
|
||||
|
||||
func TestTimeoutValidStrategy(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
strategy := BeforeTime(time.Now().Add(2 * time.Second))
|
||||
assert.Bool(strategy.IsValid()).IsTrue()
|
||||
assert(strategy.IsValid(), IsTrue)
|
||||
time.Sleep(3 * time.Second)
|
||||
assert.Bool(strategy.IsValid()).IsFalse()
|
||||
assert(strategy.IsValid(), IsFalse)
|
||||
|
||||
strategy = BeforeTime(time.Now().Add(2 * time.Second))
|
||||
strategy.Invalidate()
|
||||
assert.Bool(strategy.IsValid()).IsFalse()
|
||||
assert(strategy.IsValid(), IsFalse)
|
||||
}
|
||||
|
@ -5,11 +5,11 @@ import (
|
||||
"time"
|
||||
|
||||
. "v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestGenerateRandomInt64InRange(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
base := time.Now().Unix()
|
||||
delta := 100
|
||||
@ -17,7 +17,7 @@ func TestGenerateRandomInt64InRange(t *testing.T) {
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
val := int64(generator())
|
||||
assert.Int64(val).AtMost(base + int64(delta))
|
||||
assert.Int64(val).AtLeast(base - int64(delta))
|
||||
assert(val, AtMost, base+int64(delta))
|
||||
assert(val, AtLeast, base-int64(delta))
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,5 @@
|
||||
package protocol
|
||||
|
||||
import "time"
|
||||
|
||||
func (u *User) GetTypedAccount() (Account, error) {
|
||||
if u.GetAccount() == nil {
|
||||
return nil, newError("Account missing").AtWarning()
|
||||
@ -19,20 +17,3 @@ func (u *User) GetTypedAccount() (Account, error) {
|
||||
}
|
||||
return nil, newError("Unknown account type: ", u.Account.Type)
|
||||
}
|
||||
|
||||
func (u *User) GetSettings() UserSettings {
|
||||
settings := UserSettings{}
|
||||
switch u.Level {
|
||||
case 0:
|
||||
settings.PayloadTimeout = time.Second * 30
|
||||
case 1:
|
||||
settings.PayloadTimeout = time.Minute * 2
|
||||
default:
|
||||
settings.PayloadTimeout = time.Minute * 5
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
type UserSettings struct {
|
||||
PayloadTimeout time.Duration
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
|
||||
"v2ray.com/core/common/errors"
|
||||
. "v2ray.com/core/common/retry"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -14,7 +14,7 @@ var (
|
||||
)
|
||||
|
||||
func TestNoRetry(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
startTime := time.Now().Unix()
|
||||
err := Timed(10, 100000).On(func() error {
|
||||
@ -22,12 +22,12 @@ func TestNoRetry(t *testing.T) {
|
||||
})
|
||||
endTime := time.Now().Unix()
|
||||
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int64(endTime - startTime).AtLeast(0)
|
||||
assert(err, IsNil)
|
||||
assert(endTime-startTime, AtLeast, int64(0))
|
||||
}
|
||||
|
||||
func TestRetryOnce(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
startTime := time.Now()
|
||||
called := 0
|
||||
@ -40,12 +40,12 @@ func TestRetryOnce(t *testing.T) {
|
||||
})
|
||||
duration := time.Since(startTime)
|
||||
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int64(int64(duration / time.Millisecond)).AtLeast(900)
|
||||
assert(err, IsNil)
|
||||
assert(int64(duration/time.Millisecond), AtLeast, int64(900))
|
||||
}
|
||||
|
||||
func TestRetryMultiple(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
startTime := time.Now()
|
||||
called := 0
|
||||
@ -58,12 +58,12 @@ func TestRetryMultiple(t *testing.T) {
|
||||
})
|
||||
duration := time.Since(startTime)
|
||||
|
||||
assert.Error(err).IsNil()
|
||||
assert.Int64(int64(duration / time.Millisecond)).AtLeast(4900)
|
||||
assert(err, IsNil)
|
||||
assert(int64(duration/time.Millisecond), AtLeast, int64(4900))
|
||||
}
|
||||
|
||||
func TestRetryExhausted(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
startTime := time.Now()
|
||||
called := 0
|
||||
@ -73,12 +73,12 @@ func TestRetryExhausted(t *testing.T) {
|
||||
})
|
||||
duration := time.Since(startTime)
|
||||
|
||||
assert.Error(errors.Cause(err)).Equals(ErrRetryFailed)
|
||||
assert.Int64(int64(duration / time.Millisecond)).AtLeast(1900)
|
||||
assert(errors.Cause(err), Equals, ErrRetryFailed)
|
||||
assert(int64(duration/time.Millisecond), AtLeast, int64(1900))
|
||||
}
|
||||
|
||||
func TestExponentialBackoff(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
startTime := time.Now()
|
||||
called := 0
|
||||
@ -88,6 +88,6 @@ func TestExponentialBackoff(t *testing.T) {
|
||||
})
|
||||
duration := time.Since(startTime)
|
||||
|
||||
assert.Error(errors.Cause(err)).Equals(ErrRetryFailed)
|
||||
assert.Int64(int64(duration / time.Millisecond)).AtLeast(4000)
|
||||
assert(errors.Cause(err), Equals, ErrRetryFailed)
|
||||
assert(int64(duration/time.Millisecond), AtLeast, int64(4000))
|
||||
}
|
||||
|
@ -4,11 +4,11 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestBytesToHex(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
cases := []struct {
|
||||
input []byte
|
||||
@ -21,15 +21,15 @@ func TestBytesToHex(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range cases {
|
||||
assert.String(test.output).Equals(BytesToHexString(test.input))
|
||||
assert(test.output, Equals, BytesToHexString(test.input))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt64(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
x := int64(375134875348)
|
||||
b := Int64ToBytes(x, []byte{})
|
||||
v := BytesToInt64(b)
|
||||
assert.Int64(x).Equals(v)
|
||||
assert(x, Equals, v)
|
||||
}
|
||||
|
@ -6,15 +6,15 @@ import (
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
. "v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestUint32(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
x := uint32(458634234)
|
||||
s1 := Uint32ToBytes(x, []byte{})
|
||||
s2 := buf.New()
|
||||
common.Must(s2.AppendSupplier(WriteUint32(x)))
|
||||
assert.Bytes(s1).Equals(s2.Bytes())
|
||||
assert(s1, Equals, s2.Bytes())
|
||||
}
|
||||
|
@ -4,13 +4,13 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestGetInstance(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
p, err := GetInstance("")
|
||||
assert.Pointer(p).IsNil()
|
||||
assert.Error(err).IsNotNil()
|
||||
assert(p, IsNil)
|
||||
assert(err, IsNotNil)
|
||||
}
|
||||
|
@ -6,11 +6,11 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/signal"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestErrorOrFinish2_Error(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
c1 := make(chan error, 1)
|
||||
c2 := make(chan error, 2)
|
||||
@ -22,11 +22,11 @@ func TestErrorOrFinish2_Error(t *testing.T) {
|
||||
|
||||
c1 <- errors.New("test")
|
||||
err := <-c
|
||||
assert.String(err.Error()).Equals("test")
|
||||
assert(err.Error(), Equals, "test")
|
||||
}
|
||||
|
||||
func TestErrorOrFinish2_Error2(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
c1 := make(chan error, 1)
|
||||
c2 := make(chan error, 2)
|
||||
@ -38,11 +38,11 @@ func TestErrorOrFinish2_Error2(t *testing.T) {
|
||||
|
||||
c2 <- errors.New("test")
|
||||
err := <-c
|
||||
assert.String(err.Error()).Equals("test")
|
||||
assert(err.Error(), Equals, "test")
|
||||
}
|
||||
|
||||
func TestErrorOrFinish2_NoneError(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
c1 := make(chan error, 1)
|
||||
c2 := make(chan error, 2)
|
||||
@ -61,11 +61,11 @@ func TestErrorOrFinish2_NoneError(t *testing.T) {
|
||||
|
||||
close(c2)
|
||||
err := <-c
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
}
|
||||
|
||||
func TestErrorOrFinish2_NoneError2(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
c1 := make(chan error, 1)
|
||||
c2 := make(chan error, 2)
|
||||
@ -84,5 +84,5 @@ func TestErrorOrFinish2_NoneError2(t *testing.T) {
|
||||
|
||||
close(c1)
|
||||
err := <-c
|
||||
assert.Error(err).IsNil()
|
||||
assert(err, IsNil)
|
||||
}
|
||||
|
@ -12,8 +12,6 @@ type ActivityUpdater interface {
|
||||
type ActivityTimer struct {
|
||||
updated chan bool
|
||||
timeout chan time.Duration
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (t *ActivityTimer) Update() {
|
||||
@ -27,7 +25,7 @@ func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
|
||||
t.timeout <- timeout
|
||||
}
|
||||
|
||||
func (t *ActivityTimer) run() {
|
||||
func (t *ActivityTimer) run(ctx context.Context, cancel context.CancelFunc) {
|
||||
ticker := time.NewTicker(<-t.timeout)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
@ -36,32 +34,35 @@ func (t *ActivityTimer) run() {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-t.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case timeout := <-t.timeout:
|
||||
if timeout == 0 {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
ticker = time.NewTicker(timeout)
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case <-t.updated:
|
||||
// Updated keep waiting.
|
||||
default:
|
||||
t.cancel()
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
|
||||
timer := &ActivityTimer{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
timeout: make(chan time.Duration, 1),
|
||||
updated: make(chan bool, 1),
|
||||
}
|
||||
timer.timeout <- timeout
|
||||
go timer.run()
|
||||
return ctx, timer
|
||||
go timer.run(ctx, cancel)
|
||||
return timer
|
||||
}
|
||||
|
@ -7,26 +7,28 @@ import (
|
||||
"time"
|
||||
|
||||
. "v2ray.com/core/common/signal"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestActivityTimer(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
timer := CancelAfterInactivity(ctx, cancel, time.Second*5)
|
||||
time.Sleep(time.Second * 6)
|
||||
assert.Error(ctx.Err()).IsNotNil()
|
||||
assert(ctx.Err(), IsNotNil)
|
||||
runtime.KeepAlive(timer)
|
||||
}
|
||||
|
||||
func TestActivityTimerUpdate(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
timer := CancelAfterInactivity(ctx, cancel, time.Second*10)
|
||||
time.Sleep(time.Second * 3)
|
||||
assert.Error(ctx.Err()).IsNil()
|
||||
assert(ctx.Err(), IsNil)
|
||||
timer.SetTimeout(time.Second * 1)
|
||||
time.Sleep(time.Second * 2)
|
||||
assert.Error(ctx.Err()).IsNotNil()
|
||||
assert(ctx.Err(), IsNotNil)
|
||||
runtime.KeepAlive(timer)
|
||||
}
|
||||
|
@ -4,74 +4,74 @@ import (
|
||||
"testing"
|
||||
|
||||
. "v2ray.com/core/common/uuid"
|
||||
"v2ray.com/core/testing/assert"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestParseBytes(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
str := "2418d087-648d-4990-86e8-19dca1d006d3"
|
||||
bytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
|
||||
|
||||
uuid, err := ParseBytes(bytes)
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(uuid.String()).Equals(str)
|
||||
assert(err, IsNil)
|
||||
assert(uuid.String(), Equals, str)
|
||||
|
||||
_, err = ParseBytes([]byte{1, 3, 2, 4})
|
||||
assert.Error(err).IsNotNil()
|
||||
assert(err, IsNotNil)
|
||||
}
|
||||
|
||||
func TestParseString(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
str := "2418d087-648d-4990-86e8-19dca1d006d3"
|
||||
expectedBytes := []byte{0x24, 0x18, 0xd0, 0x87, 0x64, 0x8d, 0x49, 0x90, 0x86, 0xe8, 0x19, 0xdc, 0xa1, 0xd0, 0x06, 0xd3}
|
||||
|
||||
uuid, err := ParseString(str)
|
||||
assert.Error(err).IsNil()
|
||||
assert.Bytes(uuid.Bytes()).Equals(expectedBytes)
|
||||
assert(err, IsNil)
|
||||
assert(uuid.Bytes(), Equals, expectedBytes)
|
||||
|
||||
uuid, err = ParseString("2418d087")
|
||||
assert.Error(err).IsNotNil()
|
||||
assert(err, IsNotNil)
|
||||
|
||||
uuid, err = ParseString("2418d087-648k-4990-86e8-19dca1d006d3")
|
||||
assert.Error(err).IsNotNil()
|
||||
assert(err, IsNotNil)
|
||||
}
|
||||
|
||||
func TestNewUUID(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
uuid := New()
|
||||
uuid2, err := ParseString(uuid.String())
|
||||
|
||||
assert.Error(err).IsNil()
|
||||
assert.String(uuid.String()).Equals(uuid2.String())
|
||||
assert.Bytes(uuid.Bytes()).Equals(uuid2.Bytes())
|
||||
assert(err, IsNil)
|
||||
assert(uuid.String(), Equals, uuid2.String())
|
||||
assert(uuid.Bytes(), Equals, uuid2.Bytes())
|
||||
}
|
||||
|
||||
func TestRandom(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
uuid := New()
|
||||
uuid2 := New()
|
||||
|
||||
assert.String(uuid.String()).NotEquals(uuid2.String())
|
||||
assert.Bytes(uuid.Bytes()).NotEquals(uuid2.Bytes())
|
||||
assert(uuid.String(), NotEquals, uuid2.String())
|
||||
assert(uuid.Bytes(), NotEquals, uuid2.Bytes())
|
||||
}
|
||||
|
||||
func TestEquals(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
var uuid *UUID = nil
|
||||
var uuid2 *UUID = nil
|
||||
assert.Bool(uuid.Equals(uuid2)).IsTrue()
|
||||
assert.Bool(uuid.Equals(New())).IsFalse()
|
||||
assert(uuid.Equals(uuid2), IsTrue)
|
||||
assert(uuid.Equals(New()), IsFalse)
|
||||
}
|
||||
|
||||
func TestNext(t *testing.T) {
|
||||
assert := assert.On(t)
|
||||
assert := With(t)
|
||||
|
||||
uuid := New()
|
||||
uuid2 := uuid.Next()
|
||||
assert.Bool(uuid.Equals(uuid2)).IsFalse()
|
||||
assert(uuid.Equals(uuid2), IsFalse)
|
||||
}
|
||||
|
4
core.go
4
core.go
@ -18,9 +18,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
version = "2.41"
|
||||
version = "3.1"
|
||||
build = "Custom"
|
||||
codename = "One for all"
|
||||
codename = "die Commanderin"
|
||||
intro = "An unified platform for anti-censorship."
|
||||
)
|
||||
|
||||
|
10
loader.go
10
loader.go
@ -2,10 +2,11 @@ package core
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"v2ray.com/core/common"
|
||||
"v2ray.com/core/common/buf"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"v2ray.com/core/common"
|
||||
)
|
||||
|
||||
// ConfigLoader is an utility to load V2Ray config from external source.
|
||||
@ -30,7 +31,10 @@ func LoadConfig(format ConfigFormat, input io.Reader) (*Config, error) {
|
||||
|
||||
func loadProtobufConfig(input io.Reader) (*Config, error) {
|
||||
config := new(Config)
|
||||
data, _ := ioutil.ReadAll(input)
|
||||
data, err := buf.ReadAllToBytes(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := proto.Unmarshal(data, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,5 +1,47 @@
|
||||
// +build json
|
||||
|
||||
package main
|
||||
|
||||
import _ "v2ray.com/core/tools/conf"
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"v2ray.com/core"
|
||||
"v2ray.com/core/common/platform"
|
||||
)
|
||||
|
||||
func jsonToProto(input io.Reader) (*core.Config, error) {
|
||||
v2ctl := platform.GetToolLocation("v2ctl")
|
||||
_, err := os.Stat(v2ctl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cmd := exec.Command(v2ctl, "config")
|
||||
cmd.Stdin = input
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
stdoutReader, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stdoutReader.Close()
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, err := core.LoadConfig(core.ConfigFormat_Protobuf, stdoutReader)
|
||||
|
||||
cmd.Wait()
|
||||
|
||||
return config, err
|
||||
}
|
||||
|
||||
func init() {
|
||||
core.RegisterConfigLoader(core.ConfigFormat_JSON, func(input io.Reader) (*core.Config, error) {
|
||||
config, err := jsonToProto(input)
|
||||
if err != nil {
|
||||
return nil, newError("failed to execute v2ctl to convert config file.").Base(err)
|
||||
}
|
||||
return config, nil
|
||||
})
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
// The following are necessary as they register handlers in their init functions.
|
||||
_ "v2ray.com/core/app/dispatcher/impl"
|
||||
_ "v2ray.com/core/app/dns/server"
|
||||
_ "v2ray.com/core/app/policy/manager"
|
||||
_ "v2ray.com/core/app/proxyman/inbound"
|
||||
_ "v2ray.com/core/app/proxyman/outbound"
|
||||
_ "v2ray.com/core/app/router"
|
||||
|
15
main/main.go
15
main/main.go
@ -22,6 +22,7 @@ var (
|
||||
version = flag.Bool("version", false, "Show current version of V2Ray.")
|
||||
test = flag.Bool("test", false, "Test config file only, without launching V2Ray server.")
|
||||
format = flag.String("format", "json", "Format of input file.")
|
||||
plugin = flag.Bool("plugin", false, "True to load plugins.")
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -67,7 +68,7 @@ func startV2Ray() (core.Server, error) {
|
||||
|
||||
server, err := core.New(config)
|
||||
if err != nil {
|
||||
return nil, newError("failed to create initialize").Base(err)
|
||||
return nil, newError("failed to create server").Base(err)
|
||||
}
|
||||
|
||||
return server, nil
|
||||
@ -82,19 +83,27 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
if *plugin {
|
||||
if err := core.LoadPlugins(); err != nil {
|
||||
fmt.Println("Failed to load plugins:", err.Error())
|
||||
os.Exit(-1)
|
||||
}
|
||||
}
|
||||
|
||||
server, err := startV2Ray()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
return
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
if *test {
|
||||
fmt.Println("Configuration OK.")
|
||||
return
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if err := server.Start(); err != nil {
|
||||
fmt.Println("Failed to start", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
osSignals := make(chan os.Signal, 1)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user