commit 2e4aea518881dc4631dcf8018701644dada42f41 Author: wenyifan Date: Wed Aug 3 10:04:41 2022 +0800 init diff --git a/.config/bypass.txt b/.config/bypass.txt new file mode 100644 index 0000000..e8471dd --- /dev/null +++ b/.config/bypass.txt @@ -0,0 +1,34 @@ +# period for live reloading +reload 10s + +# matcher reversed + reverse true + +*.example.com + +# this will match example.org and *.example.org +.example.org + +# From IANA IPv4 Special-Purpose Address Registry +# http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml + +0.0.0.0/8 # RFC1122: "This host on this network" +10.0.0.0/8 # RFC1918: Private-Use +100.64.0.0/10 # RFC6598: Shared Address Space +127.0.0.0/8 # RFC1122: Loopback +169.254.0.0/16 # RFC3927: Link Local +172.16.0.0/12 # RFC1918: Private-Use +192.0.0.0/24 # RFC6890: IETF Protocol Assignments +192.0.2.0/24 # RFC5737: Documentation (TEST-NET-1) +192.88.99.0/24 # RFC3068: 6to4 Relay Anycast +192.168.0.0/16 # RFC1918: Private-Use +198.18.0.0/15 # RFC2544: Benchmarking +198.51.100.0/24 # RFC5737: Documentation (TEST-NET-2) +203.0.113.0/24 # RFC5737: Documentation (TEST-NET-3) +240.0.0.0/4 # RFC1112: Reserved +255.255.255.255/32 # RFC0919: Limited Broadcast + +# From IANA Multicast Address Space Registry +# http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + +224.0.0.0/4 # RFC5771: Multicast/Reserved diff --git a/.config/dns.txt b/.config/dns.txt new file mode 100644 index 0000000..d52f67f --- /dev/null +++ b/.config/dns.txt @@ -0,0 +1,19 @@ +# resolver timeout, default 30s. +timeout 10s + +# resolver cache TTL, +# minus value means that cache is disabled, +# default to the TTL in DNS server response. +# ttl 300s + +# period for live reloading +reload 10s + +# ip[:port] [protocol] [hostname] + +https://1.0.0.1/dns-query +1.1.1.1:853 tls cloudflare-dns.com +8.8.8.8 +8.8.8.8 tcp +1.1.1.1 udp +1.1.1.1:53 tcp \ No newline at end of file diff --git a/.config/gost.json b/.config/gost.json new file mode 100644 index 0000000..07072df --- /dev/null +++ b/.config/gost.json @@ -0,0 +1,28 @@ +{ + "Retries": 1, + "Debug": false, + "ServeNodes": [ + ":12345" + ], + "ChainNodes": [ + "http://:8080" + ], + + "Routes": [ + { + "Retries": 1, + "ServeNodes": [ + "ws://:1443" + ], + "ChainNodes": [ + "socks://:192.168.1.1:1080" + ] + }, + { + "Retries": 3, + "ServeNodes": [ + "quic://:443" + ] + } + ] +} \ No newline at end of file diff --git a/.config/hosts.txt b/.config/hosts.txt new file mode 100644 index 0000000..a2bf11d --- /dev/null +++ b/.config/hosts.txt @@ -0,0 +1,17 @@ +# period for live reloading +reload 10s + +# The following lines are desirable for IPv4 capable hosts +127.0.0.1 localhost + +# 127.0.1.1 is often used for the FQDN of the machine +127.0.1.1 thishost.mydomain.org thishost +192.168.1.10 foo.mydomain.org foo +192.168.1.13 bar.mydomain.org bar +146.82.138.7 master.debian.org master +209.237.226.90 www.opensource.org + +# The following lines are desirable for IPv6 capable hosts +::1 localhost ip6-localhost ip6-loopback +ff02::1 ip6-allnodes +ff02::2 ip6-allrouters diff --git a/.config/kcp.json b/.config/kcp.json new file mode 100644 index 0000000..00a576a --- /dev/null +++ b/.config/kcp.json @@ -0,0 +1,21 @@ +{ + "key": "it's a secrect", + "crypt": "aes", + "mode": "fast", + "mtu" : 1350, + "sndwnd": 1024, + "rcvwnd": 1024, + "datashard": 10, + "parityshard": 3, + "dscp": 0, + "nocomp": false, + "acknodelay": false, + "nodelay": 0, + "interval": 40, + "resend": 0, + "nc": 0, + "sockbuf": 4194304, + "keepalive": 10, + "snmplog": "", + "snmpperiod": 60 +} \ No newline at end of file diff --git a/.config/peer.txt b/.config/peer.txt new file mode 100644 index 0000000..eb87043 --- /dev/null +++ b/.config/peer.txt @@ -0,0 +1,14 @@ +# strategy for node selecting +strategy random + +max_fails 1 + +fail_timeout 30s + +# period for live reloading +reload 10s + +# peers +peer http://:18080 +peer socks://:11080 +peer ss://chacha20:123456@:18338 \ No newline at end of file diff --git a/.config/probe_resist.txt b/.config/probe_resist.txt new file mode 100644 index 0000000..c57eff5 --- /dev/null +++ b/.config/probe_resist.txt @@ -0,0 +1 @@ +Hello World! \ No newline at end of file diff --git a/.config/secrets.txt b/.config/secrets.txt new file mode 100644 index 0000000..fe86322 --- /dev/null +++ b/.config/secrets.txt @@ -0,0 +1,11 @@ +# period for live reloading +reload 3s + +# username password + +$test.admin$ $123456$ +@test.admin@ @123456@ +test.admin# #123456# +test.admin\admin 123456 +test001 123456 +test002 12345678 \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..c121e41 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test +release +debian +docs + +*.exe +*.test + +*.bak + +.git +.gitignore +LICENSE +VERSION +README.md +Changelog.md +Makefile +docker-compose.yml \ No newline at end of file diff --git a/.github/workflows/buildx.yaml b/.github/workflows/buildx.yaml new file mode 100644 index 0000000..0ea67d2 --- /dev/null +++ b/.github/workflows/buildx.yaml @@ -0,0 +1,74 @@ +# ref: https://github.com/crazy-max/diun/blob/master/.github/workflows/build.yml + +name: Docker +on: [push] +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Prepare + id: prepare + run: | + if [[ $GITHUB_REF == refs/tags/* ]]; then + echo ::set-output name=version::${GITHUB_REF#refs/tags/v} + elif [[ $GITHUB_REF == refs/heads/master ]]; then + echo ::set-output name=version::latest + elif [[ $GITHUB_REF == refs/heads/* ]]; then + echo ::set-output name=version::${GITHUB_REF#refs/heads/} + else + echo ::set-output name=version::snapshot + fi + + echo ::set-output name=docker_platforms::linux/amd64,linux/arm/v6,linux/arm/v7,linux/arm64/v8,linux/386,linux/s390x + echo ::set-output name=docker_image::${{ secrets.DOCKER_USERNAME }}/${{ github.event.repository.name }} + + # https://github.com/crazy-max/ghaction-docker-buildx + - name: Set up Docker Buildx + id: buildx + uses: crazy-max/ghaction-docker-buildx@v1 + with: + version: latest + + - name: Environment + run: | + echo home=$HOME + echo git_ref=$GITHUB_REF + echo git_sha=$GITHUB_SHA + echo version=${{ steps.prepare.outputs.version }} + echo image=${{ steps.prepare.outputs.docker_image }} + echo platforms=${{ steps.prepare.outputs.docker_platforms }} + echo avail_platforms=${{ steps.buildx.outputs.platforms }} + + # https://github.com/actions/checkout + - name: Checkout + uses: actions/checkout@v2 + + - name: Docker Buildx (no push) + run: | + docker buildx bake \ + --set ${{ github.event.repository.name }}.platform=${{ steps.prepare.outputs.docker_platforms }} \ + --set ${{ github.event.repository.name }}.output=type=image,push=false \ + --set ${{ github.event.repository.name }}.tags="${{ steps.prepare.outputs.docker_image }}:${{ steps.prepare.outputs.version }}" \ + --file docker-compose.yaml + + - name: Docker Login + if: success() + env: + DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} + run: | + echo "${DOCKER_PASSWORD}" | docker login --username "${{ secrets.DOCKER_USERNAME }}" --password-stdin + + - name: Docker Buildx (push) + if: success() + run: | + docker buildx bake \ + --set ${{ github.event.repository.name }}.platform=${{ steps.prepare.outputs.docker_platforms }} \ + --set ${{ github.event.repository.name }}.output=type=image,push=true \ + --set ${{ github.event.repository.name }}.tags="${{ steps.prepare.outputs.docker_image }}:${{ steps.prepare.outputs.version }}" \ + --file docker-compose.yaml + + - name: Clear + if: always() + run: | + rm -f ${HOME}/.docker/config.json + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4726eff --- /dev/null +++ b/.gitignore @@ -0,0 +1,37 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test +release +debian +bin +.idea + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.swp +*.swo + +*.exe +*.test + +*.bak + +.vscode/ +cmd/gost/gost +cmd/gost/.ssl +snap diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..772c62f --- /dev/null +++ b/.travis.yml @@ -0,0 +1,12 @@ +language: go +sudo: false +go: + - 1.x + +install: true +script: + - go test -race -v -coverprofile=coverage.txt -covermode=atomic + - cd cmd/gost && go build + +after_success: + - bash <(curl -s https://codecov.io/bash) diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..4a0a4a9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +FROM --platform=$BUILDPLATFORM golang:1.18-alpine as builder + +# Convert TARGETPLATFORM to GOARCH format +# https://github.com/tonistiigi/xx +COPY --from=tonistiigi/xx:golang / / + +ARG TARGETPLATFORM + +RUN apk add --no-cache musl-dev git gcc + +ADD . /src + +WORKDIR /src + +ENV GO111MODULE=on + +RUN cd cmd/gost && go env && go build -v + +FROM alpine:latest + +WORKDIR /bin/ + +COPY --from=builder /src/cmd/gost/gost . + +ENTRYPOINT ["/bin/gost"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2033b3a --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 ginuerzh + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..eb37053 --- /dev/null +++ b/Makefile @@ -0,0 +1,103 @@ +NAME=gost +BINDIR=bin +VERSION=$(shell cat gost.go | grep 'Version =' | sed 's/.*\"\(.*\)\".*/\1/g') +GOBUILD=CGO_ENABLED=0 go build --ldflags="-s -w" -v -x -a +GOFILES=cmd/gost/*.go + +PLATFORM_LIST = \ + darwin-amd64 \ + darwin-arm64 \ + linux-386 \ + linux-amd64 \ + linux-armv5 \ + linux-armv6 \ + linux-armv7 \ + linux-armv8 \ + linux-mips-softfloat \ + linux-mips-hardfloat \ + linux-mipsle-softfloat \ + linux-mipsle-hardfloat \ + linux-mips64 \ + linux-mips64le \ + linux-s390x \ + freebsd-386 \ + freebsd-amd64 + +WINDOWS_ARCH_LIST = \ + windows-386 \ + windows-amd64 + +all: linux-amd64 darwin-amd64 windows-amd64 # Most used + +darwin-amd64: + GOARCH=amd64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +darwin-arm64: + GOARCH=arm64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-386: + GOARCH=386 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-amd64: + GOARCH=amd64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-armv5: + GOARCH=arm GOOS=linux GOARM=5 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-armv6: + GOARCH=arm GOOS=linux GOARM=6 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-armv7: + GOARCH=arm GOOS=linux GOARM=7 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-armv8: + GOARCH=arm64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-mips-softfloat: + GOARCH=mips GOMIPS=softfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-mips-hardfloat: + GOARCH=mips GOMIPS=hardfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-mipsle-softfloat: + GOARCH=mipsle GOMIPS=softfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-mipsle-hardfloat: + GOARCH=mipsle GOMIPS=hardfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-mips64: + GOARCH=mips64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-mips64le: + GOARCH=mips64le GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +linux-s390x: + GOARCH=s390x GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +freebsd-386: + GOARCH=386 GOOS=freebsd $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +freebsd-amd64: + GOARCH=amd64 GOOS=freebsd $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) + +windows-386: + GOARCH=386 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe $(GOFILES) + +windows-amd64: + GOARCH=amd64 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe $(GOFILES) + +gz_releases=$(addsuffix .gz, $(PLATFORM_LIST)) +zip_releases=$(addsuffix .zip, $(WINDOWS_ARCH_LIST)) + +$(gz_releases): %.gz : % + chmod +x $(BINDIR)/$(NAME)-$(basename $@) + gzip -f -S -$(VERSION).gz $(BINDIR)/$(NAME)-$(basename $@) + +$(zip_releases): %.zip : % + zip -m -j $(BINDIR)/$(NAME)-$(basename $@)-$(VERSION).zip $(BINDIR)/$(NAME)-$(basename $@).exe + +all-arch: $(PLATFORM_LIST) $(WINDOWS_ARCH_LIST) + +releases: $(gz_releases) $(zip_releases) +clean: + rm $(BINDIR)/* diff --git a/README.md b/README.md new file mode 100644 index 0000000..88aef0e --- /dev/null +++ b/README.md @@ -0,0 +1,367 @@ +GO Simple Tunnel +====== + +### GO语言实现的安全隧道 + +[![GoDoc](https://godoc.org/github.com/ginuerzh/gost?status.svg)](https://godoc.org/github.com/ginuerzh/gost) +[![Go Report Card](https://goreportcard.com/badge/github.com/ginuerzh/gost)](https://goreportcard.com/report/github.com/ginuerzh/gost) +[![codecov](https://codecov.io/gh/ginuerzh/gost/branch/master/graphs/badge.svg)](https://codecov.io/gh/ginuerzh/gost/branch/master) +[![GitHub release](https://img.shields.io/github/release/ginuerzh/gost.svg)](https://github.com/ginuerzh/gost/releases/latest) +[![Docker](https://img.shields.io/docker/pulls/ginuerzh/gost.svg)](https://hub.docker.com/r/ginuerzh/gost/) + +[English README](README_en.md) + +### !!![V3版本已经可用,欢迎抢先体验](https://latest.gost.run)!!! + +特性 +------ + +* 多端口监听 +* 可设置转发代理,支持多级转发(代理链) +* 支持标准HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5代理协议 +* Web代理支持[探测防御](https://v2.gost.run/probe_resist/) +* [支持多种隧道类型](https://v2.gost.run/configuration/) +* [SOCKS5代理支持TLS协商加密](https://v2.gost.run/socks/) +* [Tunnel UDP over TCP](https://v2.gost.run/socks/) +* [TCP/UDP透明代理](https://v2.gost.run/redirect/) +* [本地/远程TCP/UDP端口转发](https://v2.gost.run/port-forwarding/) +* [支持Shadowsocks(TCP/UDP)协议](https://v2.gost.run/ss/) +* [支持SNI代理](https://v2.gost.run/sni/) +* [权限控制](https://v2.gost.run/permission/) +* [负载均衡](https://v2.gost.run/load-balancing/) +* [路由控制](https://v2.gost.run/bypass/) +* DNS[解析](https://v2.gost.run/resolver/)和[代理](https://v2.gost.run/dns/) +* [TUN/TAP设备](https://v2.gost.run/tuntap/) + +Wiki站点: [v2.gost.run](https://v2.gost.run) + +Telegram讨论群: + +Google讨论组: + +安装 +------ + +#### 二进制文件 + + + +#### 源码编译 + +```bash +git clone https://github.com/ginuerzh/gost.git +cd gost/cmd/gost +go build +``` + +#### Docker + +```bash +docker pull ginuerzh/gost +``` + +#### Homebrew + +```bash +brew install gost +``` + +#### Ubuntu商店 + +```bash +sudo snap install core +sudo snap install gost +``` + +快速上手 +------ + +#### 不设置转发代理 + + + +* 作为标准HTTP/SOCKS5代理 + +```bash +gost -L=:8080 +``` + +* 设置代理认证信息 + +```bash +gost -L=admin:123456@localhost:8080 +``` + +* 多端口监听 + +```bash +gost -L=http2://:443 -L=socks5://:1080 -L=ss://aes-128-cfb:123456@:8338 +``` + +#### 设置转发代理 + + + +```bash +gost -L=:8080 -F=192.168.1.1:8081 +``` + +* 转发代理认证 + +```bash +gost -L=:8080 -F=http://admin:123456@192.168.1.1:8081 +``` + +#### 设置多级转发代理(代理链) + + + +```bash +gost -L=:8080 -F=quic://192.168.1.1:6121 -F=socks5+wss://192.168.1.2:1080 -F=http2://192.168.1.3:443 ... -F=a.b.c.d:NNNN +``` + +gost按照-F设置的顺序通过代理链将请求最终转发给a.b.c.d:NNNN处理,每一个转发代理可以是任意HTTP/HTTPS/HTTP2/SOCKS4/SOCKS5/Shadowsocks类型代理。 + +#### 本地端口转发(TCP) + +```bash +gost -L=tcp://:2222/192.168.1.1:22 [-F=...] +``` + +将本地TCP端口2222上的数据(通过代理链)转发到192.168.1.1:22上。当代理链末端(最后一个-F参数)为SSH转发通道类型时,gost会直接使用SSH的本地端口转发功能: + +```bash +gost -L=tcp://:2222/192.168.1.1:22 -F forward+ssh://:2222 +``` + +#### 本地端口转发(UDP) + +```bash +gost -L=udp://:5353/192.168.1.1:53?ttl=60 [-F=...] +``` + +将本地UDP端口5353上的数据(通过代理链)转发到192.168.1.1:53上。 +每条转发通道都有超时时间,当超过此时间,且在此时间段内无任何数据交互,则此通道将关闭。可以通过`ttl`参数来设置超时时间,默认值为60秒。 + +**注:** 转发UDP数据时,如果有代理链,则代理链的末端(最后一个-F参数)必须是gost SOCKS5类型代理,gost会使用UDP over TCP方式进行转发。 + +#### 远程端口转发(TCP) + +```bash +gost -L=rtcp://:2222/192.168.1.1:22 [-F=... -F=socks5://172.24.10.1:1080] +``` +将172.24.10.1:2222上的数据(通过代理链)转发到192.168.1.1:22上。当代理链末端(最后一个-F参数)为SSH转发通道类型时,gost会直接使用SSH的远程端口转发功能: + +```bash +gost -L=rtcp://:2222/192.168.1.1:22 -F forward+ssh://:2222 +``` + +#### 远程端口转发(UDP) + +```bash +gost -L=rudp://:5353/192.168.1.1:53?ttl=60 [-F=... -F=socks5://172.24.10.1:1080] +``` +将172.24.10.1:5353上的数据(通过代理链)转发到192.168.1.1:53上。 +每条转发通道都有超时时间,当超过此时间,且在此时间段内无任何数据交互,则此通道将关闭。可以通过`ttl`参数来设置超时时间,默认值为60秒。 + +**注:** 转发UDP数据时,如果有代理链,则代理链的末端(最后一个-F参数)必须是GOST SOCKS5类型代理,gost会使用UDP-over-TCP方式进行转发。 + +#### HTTP2 + +gost的HTTP2支持两种模式: +* 作为标准的HTTP2代理,并向下兼容HTTPS代理。 +* 作为通道传输其他协议。 + +##### 代理模式 +服务端: +```bash +gost -L=http2://:443 +``` +客户端: +```bash +gost -L=:8080 -F=http2://server_ip:443 +``` + +##### 通道模式 +服务端: +```bash +gost -L=h2://:443 +``` +客户端: +```bash +gost -L=:8080 -F=h2://server_ip:443 +``` + +#### QUIC +gost对QUIC的支持是基于[quic-go](https://github.com/lucas-clemente/quic-go)库。 + +服务端: +```bash +gost -L=quic://:6121 +``` + +客户端: +```bash +gost -L=:8080 -F=quic://server_ip:6121 +``` + +**注:** QUIC模式只能作为代理链的第一个节点。 + +#### KCP +gost对KCP的支持是基于[kcp-go](https://github.com/xtaci/kcp-go)和[kcptun](https://github.com/xtaci/kcptun)库。 + +服务端: +```bash +gost -L=kcp://:8388 +``` + +客户端: +```bash +gost -L=:8080 -F=kcp://server_ip:8388 +``` + +gost会自动加载当前工作目录中的kcp.json(如果存在)配置文件,或者可以手动通过参数指定配置文件路径: +```bash +gost -L=kcp://:8388?c=/path/to/conf/file +``` + +**注:** KCP模式只能作为代理链的第一个节点。 + +#### SSH + +gost的SSH支持两种模式: +* 作为转发通道,配合本地/远程TCP端口转发使用。 +* 作为通道传输其他协议。 + +##### 转发模式 +服务端: +```bash +gost -L=forward+ssh://:2222 +``` +客户端: +```bash +gost -L=rtcp://:1222/:22 -F=forward+ssh://server_ip:2222 +``` + +##### 通道模式 +服务端: +```bash +gost -L=ssh://:2222 +``` +客户端: +```bash +gost -L=:8080 -F=ssh://server_ip:2222?ping=60 +``` + +可以通过`ping`参数设置心跳包发送周期,单位为秒。默认不发送心跳包。 + + +#### 透明代理 +基于iptables的透明代理。 + +```bash +gost -L=redirect://:12345 -F=http2://server_ip:443 +``` + +#### obfs4 +此功能由[@isofew](https://github.com/isofew)贡献。 + +服务端: +```bash +gost -L=obfs4://:443 +``` + +当服务端运行后会在控制台打印出连接地址供客户端使用: +``` +obfs4://:443/?cert=4UbQjIfjJEQHPOs8vs5sagrSXx1gfrDCGdVh2hpIPSKH0nklv1e4f29r7jb91VIrq4q5Jw&iat-mode=0 +``` + +客户端: +``` +gost -L=:8888 -F='obfs4://server_ip:443?cert=4UbQjIfjJEQHPOs8vs5sagrSXx1gfrDCGdVh2hpIPSKH0nklv1e4f29r7jb91VIrq4q5Jw&iat-mode=0' +``` + +加密机制 +------ + +#### HTTP + +对于HTTP可以使用TLS加密整个通讯过程,即HTTPS代理: + +服务端: + +```bash +gost -L=https://:443 +``` +客户端: + +```bash +gost -L=:8080 -F=http+tls://server_ip:443 +``` + +#### HTTP2 + +gost的HTTP2代理模式仅支持使用TLS加密的HTTP2协议,不支持明文HTTP2传输。 + +gost的HTTP2通道模式支持加密(h2)和明文(h2c)两种模式。 + +#### SOCKS5 + +gost支持标准SOCKS5协议的no-auth(0x00)和user/pass(0x02)方法,并在此基础上扩展了两个:tls(0x80)和tls-auth(0x82),用于数据加密。 + +服务端: + +```bash +gost -L=socks5://:1080 +``` + +客户端: + +```bash +gost -L=:8080 -F=socks5://server_ip:1080 +``` + +如果两端都是gost(如上)则数据传输会被加密(协商使用tls或tls-auth方法),否则使用标准SOCKS5进行通讯(no-auth或user/pass方法)。 + +#### Shadowsocks +gost对shadowsocks的支持是基于[shadowsocks-go](https://github.com/shadowsocks/shadowsocks-go)库。 + +服务端: + +```bash +gost -L=ss://chacha20:123456@:8338 +``` +客户端: + +```bash +gost -L=:8080 -F=ss://chacha20:123456@server_ip:8338 +``` + +##### Shadowsocks UDP relay + +目前仅服务端支持UDP Relay。 + +服务端: + +```bash +gost -L=ssu://chacha20:123456@:8338 +``` + +#### TLS +gost内置了TLS证书,如果需要使用其他TLS证书,有两种方法: +* 在gost运行目录放置cert.pem(公钥)和key.pem(私钥)两个文件即可,gost会自动加载运行目录下的cert.pem和key.pem文件。 +* 使用参数指定证书文件路径: +```bash +gost -L="http2://:443?cert=/path/to/my/cert/file&key=/path/to/my/key/file" +``` + +对于客户端可以通过`secure`参数开启服务器证书和域名校验: +```bash +gost -L=:8080 -F="http2://server_domain_name:443?secure=true" +``` + +对于客户端可以指定CA证书进行[证书锁定](https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning)(Certificate Pinning): +```bash +gost -L=:8080 -F="http2://:443?ca=ca.pem" +``` +证书锁定功能由[@sheerun](https://github.com/sheerun)贡献 diff --git a/README_en.md b/README_en.md new file mode 100644 index 0000000..34c9c86 --- /dev/null +++ b/README_en.md @@ -0,0 +1,421 @@ +gost - GO Simple Tunnel +====== + +### A simple security tunnel written in Golang + +[![GoDoc](https://godoc.org/github.com/ginuerzh/gost?status.svg)](https://godoc.org/github.com/ginuerzh/gost) +[![Build Status](https://travis-ci.org/ginuerzh/gost.svg?branch=master)](https://travis-ci.org/ginuerzh/gost) +[![Go Report Card](https://goreportcard.com/badge/github.com/ginuerzh/gost)](https://goreportcard.com/report/github.com/ginuerzh/gost) +[![codecov](https://codecov.io/gh/ginuerzh/gost/branch/master/graphs/badge.svg)](https://codecov.io/gh/ginuerzh/gost/branch/master) +[![GitHub release](https://img.shields.io/github/release/ginuerzh/gost.svg)](https://github.com/ginuerzh/gost/releases/latest) +[![Snap Status](https://build.snapcraft.io/badge/ginuerzh/gost.svg)](https://build.snapcraft.io/user/ginuerzh/gost) +[![Docker Build Status](https://img.shields.io/docker/build/ginuerzh/gost.svg)](https://hub.docker.com/r/ginuerzh/gost/) + +Features +------ +* Listening on multiple ports +* Multi-level forward proxy - proxy chain +* Standard HTTP/HTTPS/HTTP2/SOCKS4(A)/SOCKS5 proxy protocols support +* [Probing resistance](https://docs.ginuerzh.xyz/gost/en/probe_resist/) support for web proxy +* [Support multiple tunnel types](https://docs.ginuerzh.xyz/gost/en/configuration/) +* [TLS encryption via negotiation support for SOCKS5 proxy](https://docs.ginuerzh.xyz/gost/en/socks/) +* [Tunnel UDP over TCP](https://docs.ginuerzh.xyz/gost/en/socks/) +* [TCP/UDP Transparent proxy](https://docs.ginuerzh.xyz/gost/en/redirect/) +* [Local/remote TCP/UDP port forwarding](https://docs.ginuerzh.xyz/gost/en/port-forwarding/) +* [Shadowsocks protocol](https://docs.ginuerzh.xyz/gost/en/ss/) +* [SNI proxy](https://docs.ginuerzh.xyz/gost/en/sni/) +* [Permission control](https://docs.ginuerzh.xyz/gost/en/permission/) +* [Load balancing](https://docs.ginuerzh.xyz/gost/en/load-balancing/) +* [Routing control](https://docs.ginuerzh.xyz/gost/en/bypass/) +* DNS [resolver](https://docs.ginuerzh.xyz/gost/resolver/) and [proxy](https://docs.ginuerzh.xyz/gost/dns/) +* [TUN/TAP device](https://docs.ginuerzh.xyz/gost/en/tuntap/) + +Wiki: + +Telegram group: + +Google group: + +Installation +------ + +#### Binary files + + + +#### From source + +```bash +git clone https://github.com/ginuerzh/gost.git +cd gost/cmd/gost +go build +``` + +#### Docker + +```bash +docker pull ginuerzh/gost +``` + +#### Homebrew + +```bash +brew install gost +``` + +#### Ubuntu store + +```bash +sudo snap install core +sudo snap install gost +``` + +Getting started +------ + +#### No forward proxy + + + +* Standard HTTP/SOCKS5 proxy + +```bash +gost -L=:8080 +``` + +* Proxy authentication + +```bash +gost -L=admin:123456@localhost:8080 +``` + +* Multiple sets of authentication information + +```bash +gost -L=localhost:8080?secrets=secrets.txt +``` + +The secrets parameter allows you to set multiple authentication information for HTTP/SOCKS5 proxies, the format is: + +```plain +# username password + +test001 123456 +test002 12345678 +``` + +* Listen on multiple ports + +```bash +gost -L=http2://:443 -L=socks5://:1080 -L=ss://aes-128-cfb:123456@:8338 +``` + +#### Forward proxy + + + +```bash +gost -L=:8080 -F=192.168.1.1:8081 +``` + +* Forward proxy authentication + +```bash +gost -L=:8080 -F=http://admin:123456@192.168.1.1:8081 +``` + +#### Multi-level forward proxy + + + +```bash +gost -L=:8080 -F=quic://192.168.1.1:6121 -F=socks5+wss://192.168.1.2:1080 -F=http2://192.168.1.3:443 ... -F=a.b.c.d:NNNN +``` + +Gost forwards the request to a.b.c.d:NNNN through the proxy chain in the order set by -F, +each forward proxy can be any HTTP/HTTPS/HTTP2/SOCKS4/SOCKS5/Shadowsocks type. + +#### Local TCP port forwarding + +```bash +gost -L=tcp://:2222/192.168.1.1:22 [-F=...] +``` + +The data on the local TCP port 2222 is forwarded to 192.168.1.1:22 (through the proxy chain). If the last node of the chain (the last -F parameter) is a SSH forwad tunnel, then gost will use the local port forwarding function of SSH directly: + +```bash +gost -L=tcp://:2222/192.168.1.1:22 -F forward+ssh://:2222 +``` + +#### Local UDP port forwarding + +```bash +gost -L=udp://:5353/192.168.1.1:53?ttl=60 [-F=...] +``` + +The data on the local UDP port 5353 is forwarded to 192.168.1.1:53 (through the proxy chain). +Each forwarding channel has a timeout period. When this time is exceeded and there is no data interaction during this time period, the channel will be closed. The timeout value can be set by the `ttl` parameter. The default value is 60 seconds. + +**NOTE:** When forwarding UDP data, if there is a proxy chain, the end of the chain (the last -F parameter) must be gost SOCKS5 proxy, gost will use UDP-over-TCP to forward data. + +#### Remote TCP port forwarding + +```bash +gost -L=rtcp://:2222/192.168.1.1:22 [-F=... -F=socks5://172.24.10.1:1080] +``` + +The data on 172.24.10.1:2222 is forwarded to 192.168.1.1:22 (through the proxy chain). If the last node of the chain (the last -F parameter) is a SSH tunnel, then gost will use the remote port forwarding function of SSH directly: + +```bash +gost -L=rtcp://:2222/192.168.1.1:22 -F forward+ssh://:2222 +``` + +#### Remote UDP port forwarding + +```bash +gost -L=rudp://:5353/192.168.1.1:53?ttl=60 [-F=... -F=socks5://172.24.10.1:1080] +``` + +The data on 172.24.10.1:5353 is forwarded to 192.168.1.1:53 (through the proxy chain). +Each forwarding channel has a timeout period. When this time is exceeded and there is no data interaction during this time period, the channel will be closed. The timeout value can be set by the `ttl` parameter. The default value is 60 seconds. + +**NOTE:** When forwarding UDP data, if there is a proxy chain, the end of the chain (the last -F parameter) must be gost SOCKS5 proxy, gost will use UDP-over-TCP to forward data. + +#### HTTP2 + +Gost HTTP2 supports two modes: + +* As a standard HTTP2 proxy, and backwards-compatible with the HTTPS proxy. + +* As a transport tunnel. + +##### Standard proxy + +Server: + +```bash +gost -L=http2://:443 +``` + +Client: + +```bash +gost -L=:8080 -F=http2://server_ip:443?ping=30 +``` + +##### Tunnel + +Server: + +```bash +gost -L=h2://:443 +``` + +Client: + +```bash +gost -L=:8080 -F=h2://server_ip:443 +``` + +#### QUIC + +Support for QUIC is based on library [quic-go](https://github.com/lucas-clemente/quic-go). + +Server: + +```bash +gost -L=quic://:6121 +``` + +Client: + +```bash +gost -L=:8080 -F=quic://server_ip:6121 +``` + +**NOTE:** QUIC node can only be used as the first node of the proxy chain. + +#### KCP +Support for KCP is based on libraries [kcp-go](https://github.com/xtaci/kcp-go) and [kcptun](https://github.com/xtaci/kcptun). + +Server: + +```bash +gost -L=kcp://:8388 +``` + +Client: + +```bash +gost -L=:8080 -F=kcp://server_ip:8388 +``` + +Gost will automatically load kcp.json configuration file from current working directory if exists, +or you can use the parameter to specify the path to the file. + +```bash +gost -L=kcp://:8388?c=/path/to/conf/file +``` + +**NOTE:** KCP node can only be used as the first node of the proxy chain. + +#### SSH + +Gost SSH supports two modes: + +* As a forward tunnel, used by local/remote TCP port forwarding. + +* As a transport tunnel. + + +##### Forward tunnel + +Server: + +```bash +gost -L=forward+ssh://:2222 +``` + +Client: + +```bash +gost -L=rtcp://:1222/:22 -F=forward+ssh://server_ip:2222 +``` + +##### Transport tunnel +Server: + +```bash +gost -L=ssh://:2222 +``` +Client: + +```bash +gost -L=:8080 -F=ssh://server_ip:2222?ping=60 +``` + +The client supports the ping parameter to enable heartbeat detection (which is disabled by default). Parameter value represents heartbeat interval seconds. + +#### Transparent proxy +Iptables-based transparent proxy + +```bash +gost -L=redirect://:12345 -F=http2://server_ip:443 +``` + + +#### obfs4 +Contributed by [@isofew](https://github.com/isofew). + +Server: + +```bash +gost -L=obfs4://:443 +``` + +When the server is running normally, the console prints out the connection address for the client to use: + +```bash +obfs4://:443/?cert=4UbQjIfjJEQHPOs8vs5sagrSXx1gfrDCGdVh2hpIPSKH0nklv1e4f29r7jb91VIrq4q5Jw&iat-mode=0 +``` + +Client: + +```bash +gost -L=:8888 -F='obfs4://server_ip:443?cert=4UbQjIfjJEQHPOs8vs5sagrSXx1gfrDCGdVh2hpIPSKH0nklv1e4f29r7jb91VIrq4q5Jw&iat-mode=0' +``` + +Encryption Mechanism +------ + +#### HTTP + +For HTTP, you can use TLS to encrypt the entire communication process, the HTTPS proxy: + +Server: + +```bash +gost -L=http+tls://:443 +``` + +Client: + +```bash +gost -L=:8080 -F=http+tls://server_ip:443 +``` + +#### HTTP2 + +Gost HTTP2 proxy mode only supports the use of TLS encrypted HTTP2 protocol, does not support plaintext HTTP2. + +Gost HTTP2 tunnel mode supports both encryption (h2) and plaintext (h2c) modes. + +#### SOCKS5 + +Gost supports the standard SOCKS5 protocol methods: no-auth (0x00) and user/pass (0x02), +and extends two methods for data encryption: tls(0x80) and tls-auth(0x82). + +Server: + +```bash +gost -L=socks://:1080 +``` + +Client: + +```bash +gost -L=:8080 -F=socks://server_ip:1080 +``` + +If both ends are gosts (as example above), the data transfer will be encrypted (using tls or tls-auth). +Otherwise, use standard SOCKS5 for communication (no-auth or user/pass). + +#### Shadowsocks +Support for shadowsocks is based on library [shadowsocks-go](https://github.com/shadowsocks/shadowsocks-go). + +Server: + +```bash +gost -L=ss://aes-128-cfb:123456@:8338 +``` + +Client: + +```bash +gost -L=:8080 -F=ss://aes-128-cfb:123456@server_ip:8338 +``` + +##### Shadowsocks UDP relay + +Currently, only the server supports UDP Relay. + +Server: + +```bash +gost -L=ssu://aes-128-cfb:123456@:8338 +``` + +#### TLS +There is built-in TLS certificate in gost, if you need to use other TLS certificate, there are two ways: + +* Place two files cert.pem (public key) and key.pem (private key) in the current working directory, gost will automatically load them. + +* Use the parameter to specify the path to the certificate file: + +```bash +gost -L="http2://:443?cert=/path/to/my/cert/file&key=/path/to/my/key/file" +``` + +Client can specify `secure` parameter to perform server's certificate chain and host name verification: + +```bash +gost -L=:8080 -F="http2://server_domain_name:443?secure=true" +``` + +Client can specify a CA certificate to allow for [Certificate Pinning](https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning): + +```bash +gost -L=:8080 -F="http2://:443?ca=ca.pem" +``` + +Certificate Pinning is contributed by [@sheerun](https://github.com/sheerun). diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..1be96e9 --- /dev/null +++ b/auth.go @@ -0,0 +1,155 @@ +package gost + +import ( + "bufio" + "io" + "strings" + "sync" + "time" +) + +// Authenticator is an interface for user authentication. +type Authenticator interface { + Authenticate(user, password string) bool +} + +// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs. +type LocalAuthenticator struct { + kvs map[string]string + period time.Duration + stopped chan struct{} + mux sync.RWMutex +} + +// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos. +func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator { + return &LocalAuthenticator{ + kvs: kvs, + stopped: make(chan struct{}), + } +} + +// Authenticate checks the validity of the provided user-password pair. +func (au *LocalAuthenticator) Authenticate(user, password string) bool { + if au == nil { + return true + } + + au.mux.RLock() + defer au.mux.RUnlock() + + if len(au.kvs) == 0 { + return true + } + + v, ok := au.kvs[user] + return ok && (v == "" || password == v) +} + +// Add adds a key-value pair to the Authenticator. +func (au *LocalAuthenticator) Add(k, v string) { + au.mux.Lock() + defer au.mux.Unlock() + if au.kvs == nil { + au.kvs = make(map[string]string) + } + au.kvs[k] = v +} + +// Reload parses config from r, then live reloads the Authenticator. +func (au *LocalAuthenticator) Reload(r io.Reader) error { + var period time.Duration + kvs := make(map[string]string) + + if r == nil || au.Stopped() { + return nil + } + + // splitLine splits a line text by white space. + // A line started with '#' will be ignored, otherwise it is valid. + split := func(line string) []string { + if line == "" { + return nil + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + if strings.IndexByte(line, '#') == 0 { + return nil + } + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + ss := split(line) + if len(ss) == 0 { + continue + } + + switch ss[0] { + case "reload": // reload option + if len(ss) > 1 { + period, _ = time.ParseDuration(ss[1]) + } + default: + var k, v string + k = ss[0] + if len(ss) > 1 { + v = ss[1] + } + kvs[k] = v + } + } + + if err := scanner.Err(); err != nil { + return err + } + + au.mux.Lock() + defer au.mux.Unlock() + + au.period = period + au.kvs = kvs + + return nil +} + +// Period returns the reload period. +func (au *LocalAuthenticator) Period() time.Duration { + if au.Stopped() { + return -1 + } + + au.mux.RLock() + defer au.mux.RUnlock() + + return au.period +} + +// Stop stops reloading. +func (au *LocalAuthenticator) Stop() { + select { + case <-au.stopped: + default: + close(au.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (au *LocalAuthenticator) Stopped() bool { + select { + case <-au.stopped: + return true + default: + return false + } +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..68a1c93 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,191 @@ +package gost + +import ( + "bytes" + "fmt" + "io" + "net/url" + "testing" + "time" +) + +var localAuthenticatorTests = []struct { + clientUser *url.Userinfo + serverUsers []*url.Userinfo + valid bool +}{ + {nil, nil, true}, + {nil, []*url.Userinfo{url.User("admin")}, false}, + {nil, []*url.Userinfo{url.UserPassword("", "123456")}, false}, + {nil, []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + + {url.User("admin"), nil, true}, + {url.User("admin"), []*url.Userinfo{url.User("admin")}, true}, + {url.User("admin"), []*url.Userinfo{url.User("test")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, + + {url.UserPassword("", ""), nil, true}, + {url.UserPassword("", "123456"), nil, true}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, false}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + + {url.UserPassword("admin", "123456"), nil, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("test")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("test", "123456")}, false}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, + + {url.UserPassword("admin", "123456"), []*url.Userinfo{ + url.UserPassword("test", "123"), + url.UserPassword("admin", "123456"), + }, true}, +} + +func TestLocalAuthenticator(t *testing.T) { + for i, tc := range localAuthenticatorTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + au := NewLocalAuthenticator(nil) + for _, u := range tc.serverUsers { + if u != nil { + p, _ := u.Password() + au.Add(u.Username(), p) + } + } + + var u, p string + if tc.clientUser != nil { + u = tc.clientUser.Username() + p, _ = tc.clientUser.Password() + } + if au.Authenticate(u, p) != tc.valid { + t.Error("authenticate result should be", tc.valid) + } + }) + } +} + +var localAuthenticatorReloadTests = []struct { + r io.Reader + period time.Duration + kvs map[string]string + stopped bool +}{ + { + r: nil, + period: 0, + kvs: nil, + }, + { + r: bytes.NewBufferString(""), + period: 0, + }, + { + r: bytes.NewBufferString("reload 10s"), + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("# reload 10s\n"), + }, + { + r: bytes.NewBufferString("reload 10s\n#admin"), + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("reload 10s\nadmin"), + period: 10 * time.Second, + kvs: map[string]string{ + "admin": "", + }, + }, + { + r: bytes.NewBufferString("# reload 10s\nadmin"), + kvs: map[string]string{ + "admin": "", + }, + }, + { + r: bytes.NewBufferString("# reload 10s\nadmin #123456"), + kvs: map[string]string{ + "admin": "#123456", + }, + stopped: true, + }, + { + r: bytes.NewBufferString("admin \t #123456\n\n\ntest \t 123456"), + kvs: map[string]string{ + "admin": "#123456", + "test": "123456", + }, + stopped: true, + }, + { + r: bytes.NewBufferString(` + $test.admin$ $123456$ + @test.admin@ @123456@ + test.admin# #123456# + test.admin\admin 123456 + `), + kvs: map[string]string{ + "$test.admin$": "$123456$", + "@test.admin@": "@123456@", + "test.admin#": "#123456#", + "test.admin\\admin": "123456", + }, + stopped: true, + }, +} + +func TestLocalAuthenticatorReload(t *testing.T) { + isEquals := func(a, b map[string]string) bool { + if len(a) == 0 && len(b) == 0 { + return true + } + if len(a) != len(b) { + return false + } + + for k, v := range a { + if b[k] != v { + return false + } + } + return true + } + for i, tc := range localAuthenticatorReloadTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + au := NewLocalAuthenticator(nil) + + if err := au.Reload(tc.r); err != nil { + t.Error(err) + } + if au.Period() != tc.period { + t.Errorf("#%d test failed: period value should be %v, got %v", + i, tc.period, au.Period()) + } + if !isEquals(au.kvs, tc.kvs) { + t.Errorf("#%d test failed: %v, %s", i, au.kvs, tc.kvs) + } + + if tc.stopped { + au.Stop() + if au.Period() >= 0 { + t.Errorf("period of the stopped reloader should be minus value") + } + au.Stop() + } + if au.Stopped() != tc.stopped { + t.Errorf("#%d test failed: stopped value should be %v, got %v", + i, tc.stopped, au.Stopped()) + } + }) + } +} diff --git a/bypass.go b/bypass.go new file mode 100644 index 0000000..28ca8c8 --- /dev/null +++ b/bypass.go @@ -0,0 +1,298 @@ +package gost + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + "time" + + glob "github.com/gobwas/glob" +) + +// Matcher is a generic pattern matcher, +// it gives the match result of the given pattern for specific v. +type Matcher interface { + Match(v string) bool + String() string +} + +// NewMatcher creates a Matcher for the given pattern. +// The acutal Matcher depends on the pattern: +// IP Matcher if pattern is a valid IP address. +// CIDR Matcher if pattern is a valid CIDR address. +// Domain Matcher if both of the above are not. +func NewMatcher(pattern string) Matcher { + if pattern == "" { + return nil + } + if ip := net.ParseIP(pattern); ip != nil { + return IPMatcher(ip) + } + if _, inet, err := net.ParseCIDR(pattern); err == nil { + return CIDRMatcher(inet) + } + return DomainMatcher(pattern) +} + +type ipMatcher struct { + ip net.IP +} + +// IPMatcher creates a Matcher for a specific IP address. +func IPMatcher(ip net.IP) Matcher { + return &ipMatcher{ + ip: ip, + } +} + +func (m *ipMatcher) Match(ip string) bool { + if m == nil { + return false + } + return m.ip.Equal(net.ParseIP(ip)) +} + +func (m *ipMatcher) String() string { + return "ip " + m.ip.String() +} + +type cidrMatcher struct { + ipNet *net.IPNet +} + +// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. +func CIDRMatcher(inet *net.IPNet) Matcher { + return &cidrMatcher{ + ipNet: inet, + } +} + +func (m *cidrMatcher) Match(ip string) bool { + if m == nil || m.ipNet == nil { + return false + } + return m.ipNet.Contains(net.ParseIP(ip)) +} + +func (m *cidrMatcher) String() string { + return "cidr " + m.ipNet.String() +} + +type domainMatcher struct { + pattern string + glob glob.Glob +} + +// DomainMatcher creates a Matcher for a specific domain pattern, +// the pattern can be a plain domain such as 'example.com', +// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. +func DomainMatcher(pattern string) Matcher { + p := pattern + if strings.HasPrefix(pattern, ".") { + p = pattern[1:] // trim the prefix '.' + pattern = "*" + p + } + return &domainMatcher{ + pattern: p, + glob: glob.MustCompile(pattern), + } +} + +func (m *domainMatcher) Match(domain string) bool { + if m == nil || m.glob == nil { + return false + } + + if domain == m.pattern { + return true + } + return m.glob.Match(domain) +} + +func (m *domainMatcher) String() string { + return "domain " + m.pattern +} + +// Bypass is a filter for address (IP or domain). +// It contains a list of matchers. +type Bypass struct { + matchers []Matcher + period time.Duration // the period for live reloading + reversed bool + stopped chan struct{} + mux sync.RWMutex +} + +// NewBypass creates and initializes a new Bypass using matchers as its match rules. +// The rules will be reversed if the reversed is true. +func NewBypass(reversed bool, matchers ...Matcher) *Bypass { + return &Bypass{ + matchers: matchers, + reversed: reversed, + stopped: make(chan struct{}), + } +} + +// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. +// The rules will be reversed if the reverse is true. +func NewBypassPatterns(reversed bool, patterns ...string) *Bypass { + var matchers []Matcher + for _, pattern := range patterns { + if m := NewMatcher(pattern); m != nil { + matchers = append(matchers, m) + } + } + bp := NewBypass(reversed) + bp.AddMatchers(matchers...) + return bp +} + +// Contains reports whether the bypass includes addr. +func (bp *Bypass) Contains(addr string) bool { + if bp == nil || addr == "" { + return false + } + + // try to strip the port + if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { + if p, _ := strconv.Atoi(port); p > 0 { // port is valid + addr = host + } + } + + bp.mux.RLock() + defer bp.mux.RUnlock() + + if len(bp.matchers) == 0 { + return false + } + + var matched bool + for _, matcher := range bp.matchers { + if matcher == nil { + continue + } + if matcher.Match(addr) { + matched = true + break + } + } + return !bp.reversed && matched || + bp.reversed && !matched +} + +// AddMatchers appends matchers to the bypass matcher list. +func (bp *Bypass) AddMatchers(matchers ...Matcher) { + bp.mux.Lock() + defer bp.mux.Unlock() + + bp.matchers = append(bp.matchers, matchers...) +} + +// Matchers return the bypass matcher list. +func (bp *Bypass) Matchers() []Matcher { + bp.mux.RLock() + defer bp.mux.RUnlock() + + return bp.matchers +} + +// Reversed reports whether the rules of the bypass are reversed. +func (bp *Bypass) Reversed() bool { + bp.mux.RLock() + defer bp.mux.RUnlock() + + return bp.reversed +} + +// Reload parses config from r, then live reloads the bypass. +func (bp *Bypass) Reload(r io.Reader) error { + var matchers []Matcher + var period time.Duration + var reversed bool + + if r == nil || bp.Stopped() { + return nil + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + ss := splitLine(line) + if len(ss) == 0 { + continue + } + switch ss[0] { + case "reload": // reload option + if len(ss) > 1 { + period, _ = time.ParseDuration(ss[1]) + } + case "reverse": // reverse option + if len(ss) > 1 { + reversed, _ = strconv.ParseBool(ss[1]) + } + default: + matchers = append(matchers, NewMatcher(ss[0])) + } + } + + if err := scanner.Err(); err != nil { + return err + } + + bp.mux.Lock() + defer bp.mux.Unlock() + + bp.matchers = matchers + bp.period = period + bp.reversed = reversed + + return nil +} + +// Period returns the reload period. +func (bp *Bypass) Period() time.Duration { + if bp.Stopped() { + return -1 + } + + bp.mux.RLock() + defer bp.mux.RUnlock() + + return bp.period +} + +// Stop stops reloading. +func (bp *Bypass) Stop() { + select { + case <-bp.stopped: + default: + close(bp.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (bp *Bypass) Stopped() bool { + select { + case <-bp.stopped: + return true + default: + return false + } +} + +func (bp *Bypass) String() string { + b := &bytes.Buffer{} + fmt.Fprintf(b, "reversed: %v\n", bp.Reversed()) + fmt.Fprintf(b, "reload: %v\n", bp.Period()) + for _, m := range bp.Matchers() { + b.WriteString(m.String()) + b.WriteByte('\n') + } + return b.String() +} diff --git a/bypass_test.go b/bypass_test.go new file mode 100644 index 0000000..d895121 --- /dev/null +++ b/bypass_test.go @@ -0,0 +1,303 @@ +package gost + +import ( + "bytes" + "fmt" + "io" + "testing" + "time" +) + +var bypassContainTests = []struct { + patterns []string + reversed bool + addr string + bypassed bool +}{ + // empty pattern + {[]string{""}, false, "", false}, + {[]string{""}, false, "192.168.1.1", false}, + {[]string{""}, true, "", false}, + {[]string{""}, true, "192.168.1.1", false}, + + // IP address + {[]string{"192.168.1.1"}, false, "192.168.1.1", true}, + {[]string{"192.168.1.1"}, true, "192.168.1.1", false}, + {[]string{"192.168.1.1"}, false, "192.168.1.2", false}, + {[]string{"192.168.1.1"}, true, "192.168.1.2", true}, + {[]string{"0.0.0.0"}, false, "0.0.0.0", true}, + {[]string{"0.0.0.0"}, true, "0.0.0.0", false}, + + // CIDR address + {[]string{"192.168.1.0/0"}, false, "1.2.3.4", true}, + {[]string{"192.168.1.0/0"}, true, "1.2.3.4", false}, + {[]string{"192.168.1.0/8"}, false, "192.1.0.255", true}, + {[]string{"192.168.1.0/8"}, true, "192.1.0.255", false}, + {[]string{"192.168.1.0/8"}, false, "191.1.0.255", false}, + {[]string{"192.168.1.0/8"}, true, "191.1.0.255", true}, + {[]string{"192.168.1.0/16"}, false, "192.168.0.255", true}, + {[]string{"192.168.1.0/16"}, true, "192.168.0.255", false}, + {[]string{"192.168.1.0/16"}, false, "192.0.1.255", false}, + {[]string{"192.168.1.0/16"}, true, "192.0.0.255", true}, + {[]string{"192.168.1.0/24"}, false, "192.168.1.255", true}, + {[]string{"192.168.1.0/24"}, true, "192.168.1.255", false}, + {[]string{"192.168.1.0/24"}, false, "192.168.0.255", false}, + {[]string{"192.168.1.0/24"}, true, "192.168.0.255", true}, + {[]string{"192.168.1.1/32"}, false, "192.168.1.1", true}, + {[]string{"192.168.1.1/32"}, true, "192.168.1.1", false}, + {[]string{"192.168.1.1/32"}, false, "192.168.1.2", false}, + {[]string{"192.168.1.1/32"}, true, "192.168.1.2", true}, + + // plain domain + {[]string{"www.example.com"}, false, "www.example.com", true}, + {[]string{"www.example.com"}, true, "www.example.com", false}, + {[]string{"http://www.example.com"}, false, "http://www.example.com", true}, + {[]string{"http://www.example.com"}, true, "http://www.example.com", false}, + {[]string{"http://www.example.com"}, false, "http://example.com", false}, + {[]string{"http://www.example.com"}, true, "http://example.com", true}, + {[]string{"www.example.com"}, false, "example.com", false}, + {[]string{"www.example.com"}, true, "example.com", true}, + + // host:port + {[]string{"192.168.1.1"}, false, "192.168.1.1:80", true}, + {[]string{"192.168.1.1"}, true, "192.168.1.1:80", false}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1", false}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1", true}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1:80", false}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1:80", true}, + {[]string{"192.168.1.1:80"}, false, "192.168.1.1:8080", false}, + {[]string{"192.168.1.1:80"}, true, "192.168.1.1:8080", true}, + + {[]string{"example.com"}, false, "example.com:80", true}, + {[]string{"example.com"}, true, "example.com:80", false}, + {[]string{"example.com:80"}, false, "example.com", false}, + {[]string{"example.com:80"}, true, "example.com", true}, + {[]string{"example.com:80"}, false, "example.com:80", false}, + {[]string{"example.com:80"}, true, "example.com:80", true}, + {[]string{"example.com:80"}, false, "example.com:8080", false}, + {[]string{"example.com:80"}, true, "example.com:8080", true}, + + // domain wildcard + + {[]string{"*"}, false, "", false}, + {[]string{"*"}, false, "192.168.1.1", true}, + {[]string{"*"}, false, "192.168.0.0/16", true}, + {[]string{"*"}, false, "http://example.com", true}, + {[]string{"*"}, false, "example.com:80", true}, + {[]string{"*"}, true, "", false}, + {[]string{"*"}, true, "192.168.1.1", false}, + {[]string{"*"}, true, "192.168.0.0/16", false}, + {[]string{"*"}, true, "http://example.com", false}, + {[]string{"*"}, true, "example.com:80", false}, + + // sub-domain + {[]string{"*.example.com"}, false, "example.com", false}, + {[]string{"*.example.com"}, false, "http://example.com", false}, + {[]string{"*.example.com"}, false, "www.example.com", true}, + {[]string{"*.example.com"}, false, "http://www.example.com", true}, + {[]string{"*.example.com"}, false, "abc.def.example.com", true}, + + {[]string{"*.*.example.com"}, false, "example.com", false}, + {[]string{"*.*.example.com"}, false, "www.example.com", false}, + {[]string{"*.*.example.com"}, false, "abc.def.example.com", true}, + {[]string{"*.*.example.com"}, false, "abc.def.ghi.example.com", true}, + + {[]string{"**.example.com"}, false, "example.com", false}, + {[]string{"**.example.com"}, false, "www.example.com", true}, + {[]string{"**.example.com"}, false, "abc.def.ghi.example.com", true}, + + // prefix wildcard + {[]string{"*example.com"}, false, "example.com", true}, + {[]string{"*example.com"}, false, "www.example.com", true}, + {[]string{"*example.com"}, false, "abc.defexample.com", true}, + {[]string{"*example.com"}, false, "abc.def-example.com", true}, + {[]string{"*example.com"}, false, "abc.def.example.com", true}, + {[]string{"*example.com"}, false, "http://www.example.com", true}, + {[]string{"*example.com"}, false, "e-xample.com", false}, + + {[]string{"http://*.example.com"}, false, "example.com", false}, + {[]string{"http://*.example.com"}, false, "http://example.com", false}, + {[]string{"http://*.example.com"}, false, "http://www.example.com", true}, + {[]string{"http://*.example.com"}, false, "https://www.example.com", false}, + {[]string{"http://*.example.com"}, false, "http://abc.def.example.com", true}, + + {[]string{"www.*.com"}, false, "www.example.com", true}, + {[]string{"www.*.com"}, false, "www.abc.def.com", true}, + + {[]string{"www.*.*.com"}, false, "www.example.com", false}, + {[]string{"www.*.*.com"}, false, "www.abc.def.com", true}, + {[]string{"www.*.*.com"}, false, "www.abc.def.ghi.com", true}, + + {[]string{"www.*example*.com"}, false, "www.example.com", true}, + {[]string{"www.*example*.com"}, false, "www.abc.example.def.com", true}, + {[]string{"www.*example*.com"}, false, "www.e-xample.com", false}, + + {[]string{"www.example.*"}, false, "www.example.com", true}, + {[]string{"www.example.*"}, false, "www.example.io", true}, + {[]string{"www.example.*"}, false, "www.example.com.cn", true}, + + {[]string{".example.com"}, false, "www.example.com", true}, + {[]string{".example.com"}, false, "example.com", true}, + {[]string{".example.com"}, false, "www.example.com.cn", false}, + + {[]string{"example.com*"}, false, "example.com", true}, + {[]string{"example.com:*"}, false, "example.com", false}, + {[]string{"example.com:*"}, false, "example.com:80", false}, + {[]string{"example.com:*"}, false, "example.com:8080", false}, + {[]string{"example.com:*"}, false, "example.com:http", true}, + {[]string{"example.com:*"}, false, "http://example.com:80", false}, + + {[]string{"*example.com*"}, false, "example.com:80", true}, + {[]string{"*example.com:*"}, false, "example.com:80", false}, + + {[]string{".example.com:*"}, false, "www.example.com", false}, + {[]string{".example.com:*"}, false, "http://www.example.com", false}, + {[]string{".example.com:*"}, false, "example.com:80", false}, + {[]string{".example.com:*"}, false, "www.example.com:8080", false}, + {[]string{".example.com:*"}, false, "http://www.example.com:80", true}, +} + +func TestBypassContains(t *testing.T) { + for i, tc := range bypassContainTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + bp := NewBypassPatterns(tc.reversed, tc.patterns...) + if bp.Contains(tc.addr) != tc.bypassed { + t.Errorf("#%d test failed: %v, %s", i, tc.patterns, tc.addr) + } + }) + } +} + +var bypassReloadTests = []struct { + r io.Reader + + reversed bool + period time.Duration + + addr string + bypassed bool + stopped bool +}{ + { + r: nil, + reversed: false, + period: 0, + addr: "192.168.1.1", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString(""), + reversed: false, + period: 0, + addr: "192.168.1.1", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString("reverse true\nreload 10s"), + reversed: true, + period: 10 * time.Second, + addr: "192.168.1.1", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString("reverse false\nreload 10s\n192.168.1.1"), + reversed: false, + period: 10 * time.Second, + addr: "192.168.1.1", + bypassed: true, + stopped: false, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.0.0/16"), + reversed: false, + period: 0, + addr: "192.168.10.2", + bypassed: true, + stopped: true, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24 #comment"), + reversed: false, + period: 0, + addr: "192.168.10.2", + bypassed: false, + stopped: true, + }, + { + r: bytes.NewBufferString("reverse false\nreload 10s\n192.168.1.1\n#example.com"), + reversed: false, + period: 10 * time.Second, + addr: "example.com", + bypassed: false, + stopped: false, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.1\n#example.com"), + reversed: false, + period: 0, + addr: "192.168.1.1", + bypassed: true, + stopped: true, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com #comment"), + reversed: false, + period: 0, + addr: "example.com", + bypassed: true, + stopped: true, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\n.example.com"), + reversed: false, + period: 0, + addr: "example.com", + bypassed: true, + stopped: true, + }, + { + r: bytes.NewBufferString("#reverse true\n#reload 10s\n*.example.com"), + reversed: false, + period: 0, + addr: "example.com", + bypassed: false, + stopped: true, + }, +} + +func TestByapssReload(t *testing.T) { + for i, tc := range bypassReloadTests { + bp := NewBypass(false) + if err := bp.Reload(tc.r); err != nil { + t.Error(err) + } + t.Log(bp.String()) + + if bp.Reversed() != tc.reversed { + t.Errorf("#%d test failed: reversed value should be %v, got %v", + i, tc.reversed, bp.reversed) + } + if bp.Period() != tc.period { + t.Errorf("#%d test failed: period value should be %v, got %v", + i, tc.period, bp.Period()) + } + if bp.Contains(tc.addr) != tc.bypassed { + t.Errorf("#%d test failed: %v, %s", i, bp.reversed, tc.addr) + } + if tc.stopped { + bp.Stop() + if bp.Period() >= 0 { + t.Errorf("period of the stopped reloader should be minus value") + } + bp.Stop() + } + if bp.Stopped() != tc.stopped { + t.Errorf("#%d test failed: stopped value should be %v, got %v", + i, tc.stopped, bp.Stopped()) + } + } +} diff --git a/chain.go b/chain.go new file mode 100644 index 0000000..cbda871 --- /dev/null +++ b/chain.go @@ -0,0 +1,380 @@ +package gost + +import ( + "context" + "errors" + "net" + "syscall" + "time" + + "github.com/go-log/log" +) + +var ( + // ErrEmptyChain is an error that implies the chain is empty. + ErrEmptyChain = errors.New("empty chain") +) + +// Chain is a proxy chain that holds a list of proxy node groups. +type Chain struct { + isRoute bool + Retries int + Mark int + nodeGroups []*NodeGroup + route []Node // nodes in the selected route +} + +// NewChain creates a proxy chain with a list of proxy nodes. +// It creates the node groups automatically, one group per node. +func NewChain(nodes ...Node) *Chain { + chain := &Chain{} + for _, node := range nodes { + chain.nodeGroups = append(chain.nodeGroups, NewNodeGroup(node)) + } + return chain +} + +// newRoute creates a chain route. +// a chain route is the final route after node selection. +func newRoute(nodes ...Node) *Chain { + chain := NewChain(nodes...) + chain.isRoute = true + return chain +} + +// Nodes returns the proxy nodes that the chain holds. +// The first node in each group will be returned. +func (c *Chain) Nodes() (nodes []Node) { + for _, group := range c.nodeGroups { + if ns := group.Nodes(); len(ns) > 0 { + nodes = append(nodes, ns[0]) + } + } + return +} + +// NodeGroups returns the list of node group. +func (c *Chain) NodeGroups() []*NodeGroup { + return c.nodeGroups +} + +// LastNode returns the last node of the node list. +// If the chain is empty, an empty node will be returned. +// If the last node is a node group, the first node in the group will be returned. +func (c *Chain) LastNode() Node { + if c.IsEmpty() { + return Node{} + } + group := c.nodeGroups[len(c.nodeGroups)-1] + return group.GetNode(0) +} + +// LastNodeGroup returns the last group of the group list. +func (c *Chain) LastNodeGroup() *NodeGroup { + if c.IsEmpty() { + return nil + } + return c.nodeGroups[len(c.nodeGroups)-1] +} + +// AddNode appends the node(s) to the chain. +func (c *Chain) AddNode(nodes ...Node) { + if c == nil { + return + } + for _, node := range nodes { + c.nodeGroups = append(c.nodeGroups, NewNodeGroup(node)) + } +} + +// AddNodeGroup appends the group(s) to the chain. +func (c *Chain) AddNodeGroup(groups ...*NodeGroup) { + if c == nil { + return + } + for _, group := range groups { + c.nodeGroups = append(c.nodeGroups, group) + } +} + +// IsEmpty checks if the chain is empty. +// An empty chain means that there is no proxy node or node group in the chain. +func (c *Chain) IsEmpty() bool { + return c == nil || len(c.nodeGroups) == 0 +} + +// Dial connects to the target TCP address addr through the chain. +// Deprecated: use DialContext instead. +func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) { + return c.DialContext(context.Background(), "tcp", address, opts...) +} + +// DialContext connects to the address on the named network using the provided context. +func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) { + options := &ChainOptions{} + for _, opt := range opts { + opt(options) + } + + retries := 1 + if c != nil && c.Retries > 0 { + retries = c.Retries + } + if options.Retries > 0 { + retries = options.Retries + } + + for i := 0; i < retries; i++ { + conn, err = c.dialWithOptions(ctx, network, address, options) + if err == nil { + break + } + } + return +} + +func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) { + if options == nil { + options = &ChainOptions{} + } + route, err := c.selectRouteFor(address) + if err != nil { + return nil, err + } + + ipAddr := address + if address != "" { + ipAddr = c.resolve(address, options.Resolver, options.Hosts) + } + + timeout := options.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + var controlFunction func(_ string, _ string, c syscall.RawConn) error = nil + if c != nil && c.Mark > 0 { + controlFunction = func(_, _ string, cc syscall.RawConn) error { + return cc.Control(func(fd uintptr) { + ex := setSocketMark(int(fd), c.Mark) + if ex != nil { + log.Logf("net dialer set mark %d error: %s", c.Mark, ex) + } else { + // log.Logf("net dialer set mark %d success", options.Mark) + } + }) + } + } + + if route.IsEmpty() { + switch network { + case "udp", "udp4", "udp6": + if address == "" { + return net.ListenUDP(network, nil) + } + default: + } + d := &net.Dialer{ + Timeout: timeout, + Control: controlFunction, + // LocalAddr: laddr, // TODO: optional local address + } + return d.DialContext(ctx, network, ipAddr) + } + + conn, err := route.getConn(ctx) + if err != nil { + return nil, err + } + + cOpts := append([]ConnectOption{AddrConnectOption(address)}, route.LastNode().ConnectOptions...) + cc, err := route.LastNode().Client.ConnectContext(ctx, conn, network, ipAddr, cOpts...) + if err != nil { + conn.Close() + return nil, err + } + return cc, nil +} + +func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + + if ip := hosts.Lookup(host); ip != nil { + return net.JoinHostPort(ip.String(), port) + } + if resolver != nil { + ips, err := resolver.Resolve(host) + if err != nil { + log.Logf("[resolver] %s: %v", host, err) + } + if len(ips) > 0 { + return net.JoinHostPort(ips[0].String(), port) + } + } + return addr +} + +// Conn obtains a handshaked connection to the last node of the chain. +func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) { + options := &ChainOptions{} + for _, opt := range opts { + opt(options) + } + + ctx := context.Background() + + retries := 1 + if c != nil && c.Retries > 0 { + retries = c.Retries + } + if options.Retries > 0 { + retries = options.Retries + } + + for i := 0; i < retries; i++ { + var route *Chain + route, err = c.selectRoute() + if err != nil { + continue + } + conn, err = route.getConn(ctx) + if err == nil { + break + } + } + return +} + +// getConn obtains a connection to the last node of the chain. +func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) { + if c.IsEmpty() { + err = ErrEmptyChain + return + } + nodes := c.Nodes() + node := nodes[0] + + cc, err := node.Client.Dial(node.Addr, node.DialOptions...) + if err != nil { + node.MarkDead() + return + } + + cn, err := node.Client.Handshake(cc, node.HandshakeOptions...) + if err != nil { + cc.Close() + node.MarkDead() + return + } + node.ResetDead() + + preNode := node + for _, node := range nodes[1:] { + var cc net.Conn + cc, err = preNode.Client.ConnectContext(ctx, cn, "tcp", node.Addr, preNode.ConnectOptions...) + if err != nil { + cn.Close() + node.MarkDead() + return + } + cc, err = node.Client.Handshake(cc, node.HandshakeOptions...) + if err != nil { + cn.Close() + node.MarkDead() + return + } + node.ResetDead() + + cn = cc + preNode = node + } + + conn = cn + return +} + +func (c *Chain) selectRoute() (route *Chain, err error) { + return c.selectRouteFor("") +} + +// selectRouteFor selects route with bypass testing. +func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { + if c.IsEmpty() { + return newRoute(), nil + } + if c.isRoute { + return c, nil + } + + route = newRoute() + var nl []Node + + for _, group := range c.nodeGroups { + var node Node + node, err = group.Next() + if err != nil { + return + } + + if node.Bypass.Contains(addr) { + break + } + + if node.Client.Transporter.Multiplex() { + node.DialOptions = append(node.DialOptions, + ChainDialOption(route), + ) + route = newRoute() // cutoff the chain for multiplex node. + } + + route.AddNode(node) + nl = append(nl, node) + } + + route.route = nl + + return +} + +// ChainOptions holds options for Chain. +type ChainOptions struct { + Retries int + Timeout time.Duration + Hosts *Hosts + Resolver Resolver + Mark int +} + +// ChainOption allows a common way to set chain options. +type ChainOption func(opts *ChainOptions) + +// RetryChainOption specifies the times of retry used by Chain.Dial. +func RetryChainOption(retries int) ChainOption { + return func(opts *ChainOptions) { + opts.Retries = retries + } +} + +// TimeoutChainOption specifies the timeout used by Chain.Dial. +func TimeoutChainOption(timeout time.Duration) ChainOption { + return func(opts *ChainOptions) { + opts.Timeout = timeout + } +} + +// HostsChainOption specifies the hosts used by Chain.Dial. +func HostsChainOption(hosts *Hosts) ChainOption { + return func(opts *ChainOptions) { + opts.Hosts = hosts + } +} + +// ResolverChainOption specifies the Resolver used by Chain.Dial. +func ResolverChainOption(resolver Resolver) ChainOption { + return func(opts *ChainOptions) { + opts.Resolver = resolver + } +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..32c0994 --- /dev/null +++ b/client.go @@ -0,0 +1,285 @@ +package gost + +import ( + "context" + "crypto/tls" + "net" + "net/url" + "time" + + "github.com/go-gost/gosocks5" +) + +// Client is a proxy client. +// A client is divided into two layers: connector and transporter. +// Connector is responsible for connecting to the destination address through this proxy. +// Transporter performs a handshake with this proxy. +type Client struct { + Connector + Transporter +} + +// DefaultClient is a standard HTTP proxy client. +var DefaultClient = &Client{Connector: HTTPConnector(nil), Transporter: TCPTransporter()} + +// Dial connects to the address addr via the DefaultClient. +func Dial(addr string, options ...DialOption) (net.Conn, error) { + return DefaultClient.Dial(addr, options...) +} + +// Handshake performs a handshake via the DefaultClient. +func Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return DefaultClient.Handshake(conn, options...) +} + +// Connect connects to the address addr via the DefaultClient. +func Connect(conn net.Conn, addr string) (net.Conn, error) { + return DefaultClient.Connect(conn, addr) +} + +// Connector is responsible for connecting to the destination address. +type Connector interface { + // Deprecated: use ConnectContext instead. + Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) + ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) +} + +type autoConnector struct { + User *url.Userinfo +} + +// AutoConnector is a Connector. +func AutoConnector(user *url.Userinfo) Connector { + return &autoConnector{ + User: user, + } +} + +func (c *autoConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *autoConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + var cnr Connector + switch network { + case "tcp", "tcp4", "tcp6": + cnr = &httpConnector{User: c.User} + default: + cnr = &socks5UDPTunConnector{User: c.User} + } + + return cnr.ConnectContext(ctx, conn, network, address, options...) +} + +// Transporter is responsible for handshaking with the proxy server. +type Transporter interface { + Dial(addr string, options ...DialOption) (net.Conn, error) + Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) + // Indicate that the Transporter supports multiplex + Multiplex() bool +} + +// DialOptions describes the options for Transporter.Dial. +type DialOptions struct { + Timeout time.Duration + Chain *Chain + Host string + HeaderConfig map[string]string +} + +// DialOption allows a common way to set DialOptions. +type DialOption func(opts *DialOptions) + +// TimeoutDialOption specifies the timeout used by Transporter.Dial +func TimeoutDialOption(timeout time.Duration) DialOption { + return func(opts *DialOptions) { + opts.Timeout = timeout + } +} + +// ChainDialOption specifies a chain used by Transporter.Dial +func ChainDialOption(chain *Chain) DialOption { + return func(opts *DialOptions) { + opts.Chain = chain + } +} + +// HostDialOption specifies the host used by Transporter.Dial +func HostDialOption(host string) DialOption { + return func(opts *DialOptions) { + opts.Host = host + } +} + +// HeaderConfigDialOption specifies the header used by Transporter.Dial +func HeaderConfigDialOption(HeaderConfig map[string]string) DialOption { + return func(opts *DialOptions) { + opts.HeaderConfig = HeaderConfig + } +} + +// HandshakeOptions describes the options for handshake. +type HandshakeOptions struct { + Addr string + Host string + User *url.Userinfo + Timeout time.Duration + Interval time.Duration + Retry int + TLSConfig *tls.Config + WSOptions *WSOptions + KCPConfig *KCPConfig + QUICConfig *QUICConfig + SSHConfig *SSHConfig +} + +// HandshakeOption allows a common way to set HandshakeOptions. +type HandshakeOption func(opts *HandshakeOptions) + +// AddrHandshakeOption specifies the server address +func AddrHandshakeOption(addr string) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Addr = addr + } +} + +// HostHandshakeOption specifies the hostname +func HostHandshakeOption(host string) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Host = host + } +} + +// UserHandshakeOption specifies the user used by Transporter.Handshake +func UserHandshakeOption(user *url.Userinfo) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.User = user + } +} + +// TimeoutHandshakeOption specifies the timeout used by Transporter.Handshake +func TimeoutHandshakeOption(timeout time.Duration) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Timeout = timeout + } +} + +// IntervalHandshakeOption specifies the interval time used by Transporter.Handshake +func IntervalHandshakeOption(interval time.Duration) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Interval = interval + } +} + +// RetryHandshakeOption specifies the times of retry used by Transporter.Handshake +func RetryHandshakeOption(retry int) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.Retry = retry + } +} + +// TLSConfigHandshakeOption specifies the TLS config used by Transporter.Handshake +func TLSConfigHandshakeOption(config *tls.Config) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.TLSConfig = config + } +} + +// WSOptionsHandshakeOption specifies the websocket options used by websocket handshake +func WSOptionsHandshakeOption(options *WSOptions) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.WSOptions = options + } +} + +// KCPConfigHandshakeOption specifies the KCP config used by KCP handshake +func KCPConfigHandshakeOption(config *KCPConfig) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.KCPConfig = config + } +} + +// QUICConfigHandshakeOption specifies the QUIC config used by QUIC handshake +func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.QUICConfig = config + } +} + +// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake. +func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption { + return func(opts *HandshakeOptions) { + opts.SSHConfig = config + } +} + +// ConnectOptions describes the options for Connector.Connect. +type ConnectOptions struct { + Addr string + Timeout time.Duration + User *url.Userinfo + Selector gosocks5.Selector + UserAgent string + NoTLS bool + NoDelay bool + HeaderConfig map[string]string +} + +// ConnectOption allows a common way to set ConnectOptions. +type ConnectOption func(opts *ConnectOptions) + +// AddrConnectOption specifies the corresponding address of the target. +func AddrConnectOption(addr string) ConnectOption { + return func(opts *ConnectOptions) { + opts.Addr = addr + } +} + +// TimeoutConnectOption specifies the timeout for connecting to target. +func TimeoutConnectOption(timeout time.Duration) ConnectOption { + return func(opts *ConnectOptions) { + opts.Timeout = timeout + } +} + +// UserConnectOption specifies the user info for authentication. +func UserConnectOption(user *url.Userinfo) ConnectOption { + return func(opts *ConnectOptions) { + opts.User = user + } +} + +// SelectorConnectOption specifies the SOCKS5 client selector. +func SelectorConnectOption(s gosocks5.Selector) ConnectOption { + return func(opts *ConnectOptions) { + opts.Selector = s + } +} + +// UserAgentConnectOption specifies the HTTP user-agent header. +func UserAgentConnectOption(ua string) ConnectOption { + return func(opts *ConnectOptions) { + opts.UserAgent = ua + } +} + +// NoTLSConnectOption specifies the SOCKS5 method without TLS. +func NoTLSConnectOption(b bool) ConnectOption { + return func(opts *ConnectOptions) { + opts.NoTLS = b + } +} + +// NoDelayConnectOption specifies the NoDelay option for ss.Connect. +func NoDelayConnectOption(b bool) ConnectOption { + return func(opts *ConnectOptions) { + opts.NoDelay = b + } +} + +// HeaderConnectOption specifies the NoDelay option for ss.Connect. +func HeaderConnectOption(HeaderConfig map[string]string) ConnectOption { + return func(opts *ConnectOptions) { + opts.HeaderConfig = HeaderConfig + } +} diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go new file mode 100644 index 0000000..b8fda23 --- /dev/null +++ b/cmd/gost/cfg.go @@ -0,0 +1,333 @@ +package main + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "io/ioutil" + "net" + "net/url" + "os" + "strings" + + "github.com/ginuerzh/gost" +) + +var ( + routers []router +) + +type baseConfig struct { + route + Routes []route + Debug bool +} + +func parseBaseConfig(s string) (*baseConfig, error) { + file, err := os.Open(s) + if err != nil { + return nil, err + } + defer file.Close() + + if err := json.NewDecoder(file).Decode(baseCfg); err != nil { + return nil, err + } + + return baseCfg, nil +} + +var ( + defaultCertFile = "cert.pem" + defaultKeyFile = "key.pem" +) + +// Load the certificate from cert & key files and optional client CA file, +// will use the default certificate if the provided info are invalid. +func tlsConfig(certFile, keyFile, caFile string) (*tls.Config, error) { + if certFile == "" || keyFile == "" { + certFile, keyFile = defaultCertFile, defaultKeyFile + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg := &tls.Config{Certificates: []tls.Certificate{cert}} + + if pool, _ := loadCA(caFile); pool != nil { + cfg.ClientCAs = pool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + return cfg, nil +} + +func loadCA(caFile string) (cp *x509.CertPool, err error) { + if caFile == "" { + return + } + cp = x509.NewCertPool() + data, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + if !cp.AppendCertsFromPEM(data) { + return nil, errors.New("AppendCertsFromPEM failed") + } + return +} + +func parseKCPConfig(configFile string) (*gost.KCPConfig, error) { + if configFile == "" { + return nil, nil + } + file, err := os.Open(configFile) + if err != nil { + return nil, err + } + defer file.Close() + + config := &gost.KCPConfig{} + if err = json.NewDecoder(file).Decode(config); err != nil { + return nil, err + } + return config, nil +} + +func parseUsers(authFile string) (users []*url.Userinfo, err error) { + if authFile == "" { + return + } + + file, err := os.Open(authFile) + if err != nil { + return + } + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + s := strings.SplitN(line, " ", 2) + if len(s) == 1 { + users = append(users, url.User(strings.TrimSpace(s[0]))) + } else if len(s) == 2 { + users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1]))) + } + } + + err = scanner.Err() + return +} + +func parseAuthenticator(s string) (gost.Authenticator, error) { + if s == "" { + return nil, nil + } + f, err := os.Open(s) + if err != nil { + return nil, err + } + defer f.Close() + + au := gost.NewLocalAuthenticator(nil) + au.Reload(f) + + go gost.PeriodReload(au, s) + + return au, nil +} + +func parseIP(s string, port string) (ips []string) { + if s == "" { + return + } + if port == "" { + port = "8080" // default port + } + + file, err := os.Open(s) + if err != nil { + ss := strings.Split(s, ",") + for _, s := range ss { + s = strings.TrimSpace(s) + if s != "" { + // TODO: support IPv6 + if !strings.Contains(s, ":") { + s = s + ":" + port + } + ips = append(ips, s) + } + + } + return + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if !strings.Contains(line, ":") { + line = line + ":" + port + } + ips = append(ips, line) + } + return +} + +func parseBypass(s string) *gost.Bypass { + if s == "" { + return nil + } + var matchers []gost.Matcher + var reversed bool + if strings.HasPrefix(s, "~") { + reversed = true + s = strings.TrimLeft(s, "~") + } + + f, err := os.Open(s) + if err != nil { + for _, s := range strings.Split(s, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + matchers = append(matchers, gost.NewMatcher(s)) + } + return gost.NewBypass(reversed, matchers...) + } + defer f.Close() + + bp := gost.NewBypass(reversed) + bp.Reload(f) + go gost.PeriodReload(bp, s) + + return bp +} + +func parseResolver(cfg string) gost.Resolver { + if cfg == "" { + return nil + } + var nss []gost.NameServer + + f, err := os.Open(cfg) + if err != nil { + for _, s := range strings.Split(cfg, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + if strings.HasPrefix(s, "https") { + p := "https" + u, _ := url.Parse(s) + if u == nil || u.Scheme == "" { + continue + } + if u.Scheme == "https-chain" { + p = u.Scheme + } + ns := gost.NameServer{ + Addr: s, + Protocol: p, + } + nss = append(nss, ns) + continue + } + + ss := strings.Split(s, "/") + if len(ss) == 1 { + ns := gost.NameServer{ + Addr: ss[0], + } + nss = append(nss, ns) + } + if len(ss) == 2 { + ns := gost.NameServer{ + Addr: ss[0], + Protocol: ss[1], + } + nss = append(nss, ns) + } + } + return gost.NewResolver(0, nss...) + } + defer f.Close() + + resolver := gost.NewResolver(0) + resolver.Reload(f) + + go gost.PeriodReload(resolver, cfg) + + return resolver +} + +func parseHosts(s string) *gost.Hosts { + f, err := os.Open(s) + if err != nil { + return nil + } + defer f.Close() + + hosts := gost.NewHosts() + hosts.Reload(f) + + go gost.PeriodReload(hosts, s) + + return hosts +} + +func parseIPRoutes(s string) (routes []gost.IPRoute) { + if s == "" { + return + } + + file, err := os.Open(s) + if err != nil { + ss := strings.Split(s, ",") + for _, s := range ss { + if _, inet, _ := net.ParseCIDR(strings.TrimSpace(s)); inet != nil { + routes = append(routes, gost.IPRoute{Dest: inet}) + } + } + return + } + + defer file.Close() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.Replace(scanner.Text(), "\t", " ", -1) + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + var route gost.IPRoute + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + if len(ss) > 0 && ss[0] != "" { + _, route.Dest, _ = net.ParseCIDR(strings.TrimSpace(ss[0])) + if route.Dest == nil { + continue + } + } + if len(ss) > 1 && ss[1] != "" { + route.Gateway = net.ParseIP(ss[1]) + } + routes = append(routes, route) + } + return routes +} diff --git a/cmd/gost/main.go b/cmd/gost/main.go new file mode 100644 index 0000000..7aea15b --- /dev/null +++ b/cmd/gost/main.go @@ -0,0 +1,122 @@ +package main + +import ( + "crypto/tls" + "errors" + "flag" + "fmt" + "net/http" + "os" + "runtime" + + _ "net/http/pprof" + + "github.com/ginuerzh/gost" + "github.com/go-log/log" +) + +var ( + configureFile string + baseCfg = &baseConfig{} + pprofAddr string + pprofEnabled = os.Getenv("PROFILING") != "" +) + +func init() { + gost.SetLogger(&gost.LogLogger{}) + + var ( + printVersion bool + ) + + flag.Var(&baseCfg.route.ChainNodes, "F", "forward address, can make a forward chain") + flag.Var(&baseCfg.route.ServeNodes, "L", "listen address, can listen on multiple ports (required)") + flag.IntVar(&baseCfg.route.Mark, "M", 0, "Specify out connection mark") + flag.StringVar(&configureFile, "C", "", "configure file") + flag.BoolVar(&baseCfg.Debug, "D", false, "enable debug log") + flag.BoolVar(&printVersion, "V", false, "print version") + if pprofEnabled { + flag.StringVar(&pprofAddr, "P", ":6060", "profiling HTTP server address") + } + flag.Parse() + + if printVersion { + fmt.Fprintf(os.Stdout, "gost %s (%s %s/%s)\n", + gost.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH) + os.Exit(0) + } + + if configureFile != "" { + _, err := parseBaseConfig(configureFile) + if err != nil { + log.Log(err) + os.Exit(1) + } + } + if flag.NFlag() == 0 { + flag.PrintDefaults() + os.Exit(0) + } +} + +func main() { + if pprofEnabled { + go func() { + log.Log("profiling server on", pprofAddr) + log.Log(http.ListenAndServe(pprofAddr, nil)) + }() + } + + // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. + tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile, "") + if err != nil { + // generate random self-signed certificate. + cert, err := gost.GenCertificate() + if err != nil { + log.Log(err) + os.Exit(1) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } else { + log.Log("load TLS certificate files OK") + } + + gost.DefaultTLSConfig = tlsConfig + + if err := start(); err != nil { + log.Log(err) + os.Exit(1) + } + + select {} +} + +func start() error { + gost.Debug = baseCfg.Debug + + var routers []router + rts, err := baseCfg.route.GenRouters() + if err != nil { + return err + } + routers = append(routers, rts...) + + for _, route := range baseCfg.Routes { + rts, err := route.GenRouters() + if err != nil { + return err + } + routers = append(routers, rts...) + } + + if len(routers) == 0 { + return errors.New("invalid config") + } + for i := range routers { + go routers[i].Serve() + } + + return nil +} diff --git a/cmd/gost/peer.go b/cmd/gost/peer.go new file mode 100644 index 0000000..9e1d00f --- /dev/null +++ b/cmd/gost/peer.go @@ -0,0 +1,165 @@ +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "io" + "io/ioutil" + "strconv" + "strings" + "time" + + "github.com/ginuerzh/gost" +) + +type peerConfig struct { + Strategy string `json:"strategy"` + MaxFails int `json:"max_fails"` + FailTimeout time.Duration + period time.Duration // the period for live reloading + Nodes []string `json:"nodes"` + group *gost.NodeGroup + baseNodes []gost.Node + stopped chan struct{} +} + +func newPeerConfig() *peerConfig { + return &peerConfig{ + stopped: make(chan struct{}), + } +} + +func (cfg *peerConfig) Validate() { +} + +func (cfg *peerConfig) Reload(r io.Reader) error { + if cfg.Stopped() { + return nil + } + + if err := cfg.parse(r); err != nil { + return err + } + cfg.Validate() + + group := cfg.group + group.SetSelector( + nil, + gost.WithFilter( + &gost.FailFilter{ + MaxFails: cfg.MaxFails, + FailTimeout: cfg.FailTimeout, + }, + &gost.InvalidFilter{}, + ), + gost.WithStrategy(gost.NewStrategy(cfg.Strategy)), + ) + + gNodes := cfg.baseNodes + nid := len(gNodes) + 1 + for _, s := range cfg.Nodes { + nodes, err := parseChainNode(s) + if err != nil { + return err + } + + for i := range nodes { + nodes[i].ID = nid + nid++ + } + + gNodes = append(gNodes, nodes...) + } + + nodes := group.SetNodes(gNodes...) + for _, node := range nodes[len(cfg.baseNodes):] { + if node.Bypass != nil { + node.Bypass.Stop() // clear the old nodes + } + } + + return nil +} + +func (cfg *peerConfig) parse(r io.Reader) error { + data, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + // compatible with JSON format + if err := json.NewDecoder(bytes.NewReader(data)).Decode(cfg); err == nil { + return nil + } + + split := func(line string) []string { + if line == "" { + return nil + } + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss + } + + cfg.Nodes = nil + scanner := bufio.NewScanner(bytes.NewReader(data)) + for scanner.Scan() { + line := scanner.Text() + ss := split(line) + if len(ss) < 2 { + continue + } + + switch ss[0] { + case "strategy": + cfg.Strategy = ss[1] + case "max_fails": + cfg.MaxFails, _ = strconv.Atoi(ss[1]) + case "fail_timeout": + cfg.FailTimeout, _ = time.ParseDuration(ss[1]) + case "reload": + cfg.period, _ = time.ParseDuration(ss[1]) + case "peer": + cfg.Nodes = append(cfg.Nodes, ss[1]) + } + } + + return scanner.Err() +} + +func (cfg *peerConfig) Period() time.Duration { + if cfg.Stopped() { + return -1 + } + return cfg.period +} + +// Stop stops reloading. +func (cfg *peerConfig) Stop() { + select { + case <-cfg.stopped: + default: + close(cfg.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (cfg *peerConfig) Stopped() bool { + select { + case <-cfg.stopped: + return true + default: + return false + } +} diff --git a/cmd/gost/route.go b/cmd/gost/route.go new file mode 100644 index 0000000..efa30db --- /dev/null +++ b/cmd/gost/route.go @@ -0,0 +1,734 @@ +package main + +import ( + "bufio" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" + "net" + "net/url" + "os" + "strings" + "time" + + "github.com/ginuerzh/gost" + "github.com/go-log/log" +) + +type stringList []string + +func (l *stringList) String() string { + return fmt.Sprintf("%s", *l) +} +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +type route struct { + ServeNodes stringList + ChainNodes stringList + Retries int + Mark int +} + +func (r *route) parseChain() (*gost.Chain, error) { + chain := gost.NewChain() + chain.Retries = r.Retries + chain.Mark = r.Mark + gid := 1 // group ID + + for _, ns := range r.ChainNodes { + ngroup := gost.NewNodeGroup() + ngroup.ID = gid + gid++ + + // parse the base nodes + nodes, err := parseChainNode(ns) + if err != nil { + return nil, err + } + + nid := 1 // node ID + for i := range nodes { + nodes[i].ID = nid + nid++ + } + ngroup.AddNode(nodes...) + + ngroup.SetSelector(nil, + gost.WithFilter( + &gost.FailFilter{ + MaxFails: nodes[0].GetInt("max_fails"), + FailTimeout: nodes[0].GetDuration("fail_timeout"), + }, + &gost.InvalidFilter{}, + ), + gost.WithStrategy(gost.NewStrategy(nodes[0].Get("strategy"))), + ) + + if cfg := nodes[0].Get("peer"); cfg != "" { + f, err := os.Open(cfg) + if err != nil { + return nil, err + } + + peerCfg := newPeerConfig() + peerCfg.group = ngroup + peerCfg.baseNodes = nodes + peerCfg.Reload(f) + f.Close() + + go gost.PeriodReload(peerCfg, cfg) + } + + chain.AddNodeGroup(ngroup) + } + + return chain, nil +} + +func parseChainNode(ns string) (nodes []gost.Node, err error) { + node, err := gost.ParseNode(ns) + if err != nil { + return + } + + if auth := node.Get("auth"); auth != "" && node.User == nil { + c, err := base64.StdEncoding.DecodeString(auth) + if err != nil { + return nil, err + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + node.User = url.User(cs) + } else { + node.User = url.UserPassword(cs[:s], cs[s+1:]) + } + } + if node.User == nil { + users, err := parseUsers(node.Get("secrets")) + if err != nil { + return nil, err + } + if len(users) > 0 { + node.User = users[0] + } + } + + headerCfg := getHeaderCfg(node.Get("header")) + + serverName, sport, _ := net.SplitHostPort(node.Addr) + if serverName == "" { + serverName = "localhost" // default server name + } + sni := node.Get("sni") + if sni != "" { + serverName = sni + } + + rootCAs, err := loadCA(node.Get("ca")) + if err != nil { + return + } + tlsCfg := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: !node.GetBool("secure"), + RootCAs: rootCAs, + } + + // If the argument `ca` is given, but not open `secure`, we verify the + // certificate manually. + if rootCAs != nil && !node.GetBool("secure") { + tlsCfg.VerifyConnection = func(state tls.ConnectionState) error { + opts := x509.VerifyOptions{ + Roots: rootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := state.PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + return err + } + } + + if cert, err := tls.LoadX509KeyPair(node.Get("cert"), node.Get("key")); err == nil { + tlsCfg.Certificates = []tls.Certificate{cert} + } + + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = node.GetBool("compression") + wsOpts.ReadBufferSize = node.GetInt("rbuf") + wsOpts.WriteBufferSize = node.GetInt("wbuf") + wsOpts.UserAgent = node.Get("agent") + wsOpts.Path = node.Get("path") + wsOpts.HeaderConfig = headerCfg + + timeout := node.GetDuration("timeout") + + var tr gost.Transporter + switch node.Transport { + case "tls": + tr = gost.TLSTransporter() + case "mtls": + tr = gost.MTLSTransporter() + case "ws": + tr = gost.WSTransporter(wsOpts) + case "mws": + tr = gost.MWSTransporter(wsOpts) + case "wss": + tr = gost.WSSTransporter(wsOpts) + case "mwss": + tr = gost.MWSSTransporter(wsOpts) + case "kcp": + config, err := parseKCPConfig(node.Get("c")) + if err != nil { + return nil, err + } + if config == nil { + conf := gost.DefaultKCPConfig + if node.GetBool("tcp") { + conf.TCP = true + } + config = &conf + } + tr = gost.KCPTransporter(config) + case "ssh": + if node.Protocol == "direct" || node.Protocol == "remote" { + tr = gost.SSHForwardTransporter() + } else { + tr = gost.SSHTunnelTransporter() + } + case "quic": + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: node.GetBool("keepalive"), + Timeout: timeout, + IdleTimeout: node.GetDuration("idle"), + } + + if cipher := node.Get("cipher"); cipher != "" { + sum := sha256.Sum256([]byte(cipher)) + config.Key = sum[:] + } + + tr = gost.QUICTransporter(config) + case "http2": + tr = gost.HTTP2Transporter(tlsCfg) + case "h2": + tr = gost.H2Transporter(tlsCfg, node.Get("path")) + case "h2c": + tr = gost.H2CTransporter(node.Get("path")) + case "obfs4": + tr = gost.Obfs4Transporter() + case "ohttp": + tr = gost.ObfsHTTPTransporter() + case "otls": + tr = gost.ObfsTLSTransporter() + case "ftcp": + tr = gost.FakeTCPTransporter() + case "udp": + tr = gost.UDPTransporter() + default: + tr = gost.TCPTransporter() + } + + var connector gost.Connector + switch node.Protocol { + case "http2": + connector = gost.HTTP2Connector(node.User) + case "socks", "socks5": + connector = gost.SOCKS5Connector(node.User) + case "socks4": + connector = gost.SOCKS4Connector() + case "socks4a": + connector = gost.SOCKS4AConnector() + case "ss": + connector = gost.ShadowConnector(node.User) + case "ssu": + connector = gost.ShadowUDPConnector(node.User) + case "direct": + connector = gost.SSHDirectForwardConnector() + case "remote": + connector = gost.SSHRemoteForwardConnector() + case "forward": + connector = gost.ForwardConnector() + case "sni": + connector = gost.SNIConnector(node.Get("host")) + case "http": + connector = gost.HTTPConnector(node.User) + case "relay": + connector = gost.RelayConnector(node.User) + default: + connector = gost.AutoConnector(node.User) + } + + host := node.Get("host") + if host == "" { + if sni != "" { + index := strings.Index(node.Host, ":") + if index < 0 { + host = sni + } else { + host = sni + node.Host[index:] + } + } else { + host = node.Host + } + + } + + node.DialOptions = append(node.DialOptions, + gost.TimeoutDialOption(timeout), + gost.HostDialOption(host), + gost.HeaderConfigDialOption(headerCfg), + ) + + node.ConnectOptions = []gost.ConnectOption{ + gost.UserAgentConnectOption(node.Get("agent")), + gost.NoTLSConnectOption(node.GetBool("notls")), + gost.NoDelayConnectOption(node.GetBool("nodelay")), + gost.HeaderConnectOption(headerCfg), + } + + sshConfig := &gost.SSHConfig{} + if s := node.Get("ssh_key"); s != "" { + key, err := gost.ParseSSHKeyFile(s) + if err != nil { + return nil, err + } + sshConfig.Key = key + } + handshakeOptions := []gost.HandshakeOption{ + gost.AddrHandshakeOption(node.Addr), + gost.HostHandshakeOption(host), + gost.UserHandshakeOption(node.User), + gost.TLSConfigHandshakeOption(tlsCfg), + gost.IntervalHandshakeOption(node.GetDuration("ping")), + gost.TimeoutHandshakeOption(timeout), + gost.RetryHandshakeOption(node.GetInt("retry")), + gost.SSHConfigHandshakeOption(sshConfig), + } + + node.Client = &gost.Client{ + Connector: connector, + Transporter: tr, + } + + node.Bypass = parseBypass(node.Get("bypass")) + + ips := parseIP(node.Get("ip"), sport) + for _, ip := range ips { + nd := node.Clone() + nd.Addr = ip + // override the default node address + nd.HandshakeOptions = append(handshakeOptions, gost.AddrHandshakeOption(ip)) + // One node per IP + nodes = append(nodes, nd) + } + if len(ips) == 0 { + node.HandshakeOptions = handshakeOptions + nodes = []gost.Node{node} + } + + if node.Transport == "obfs4" { + for i := range nodes { + if err := gost.Obfs4Init(nodes[i], false); err != nil { + return nil, err + } + } + } + + return +} + +func (r *route) GenRouters() ([]router, error) { + chain, err := r.parseChain() + if err != nil { + return nil, err + } + + var rts []router + + for _, ns := range r.ServeNodes { + node, err := gost.ParseNode(ns) + if err != nil { + return nil, err + } + + if auth := node.Get("auth"); auth != "" && node.User == nil { + c, err := base64.StdEncoding.DecodeString(auth) + if err != nil { + return nil, err + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + node.User = url.User(cs) + } else { + node.User = url.UserPassword(cs[:s], cs[s+1:]) + } + } + authenticator, err := parseAuthenticator(node.Get("secrets")) + if err != nil { + return nil, err + } + if authenticator == nil && node.User != nil { + kvs := make(map[string]string) + kvs[node.User.Username()], _ = node.User.Password() + authenticator = gost.NewLocalAuthenticator(kvs) + } + if node.User == nil { + if users, _ := parseUsers(node.Get("secrets")); len(users) > 0 { + node.User = users[0] + } + } + certFile, keyFile := node.Get("cert"), node.Get("key") + tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca")) + if err != nil && certFile != "" && keyFile != "" { + return nil, err + } + + wsOpts := &gost.WSOptions{} + wsOpts.EnableCompression = node.GetBool("compression") + wsOpts.ReadBufferSize = node.GetInt("rbuf") + wsOpts.WriteBufferSize = node.GetInt("wbuf") + wsOpts.Path = node.Get("path") + + ttl := node.GetDuration("ttl") + timeout := node.GetDuration("timeout") + + tunRoutes := parseIPRoutes(node.Get("route")) + gw := net.ParseIP(node.Get("gw")) // default gateway + for i := range tunRoutes { + if tunRoutes[i].Gateway == nil { + tunRoutes[i].Gateway = gw + } + } + + var ln gost.Listener + switch node.Transport { + case "tls": + ln, err = gost.TLSListener(node.Addr, tlsCfg) + case "mtls": + ln, err = gost.MTLSListener(node.Addr, tlsCfg) + case "ws": + ln, err = gost.WSListener(node.Addr, wsOpts) + case "mws": + ln, err = gost.MWSListener(node.Addr, wsOpts) + case "wss": + ln, err = gost.WSSListener(node.Addr, tlsCfg, wsOpts) + case "mwss": + ln, err = gost.MWSSListener(node.Addr, tlsCfg, wsOpts) + case "kcp": + config, er := parseKCPConfig(node.Get("c")) + if er != nil { + return nil, er + } + if config == nil { + conf := gost.DefaultKCPConfig + if node.GetBool("tcp") { + conf.TCP = true + } + config = &conf + } + ln, err = gost.KCPListener(node.Addr, config) + case "ssh": + config := &gost.SSHConfig{ + Authenticator: authenticator, + TLSConfig: tlsCfg, + } + if s := node.Get("ssh_key"); s != "" { + key, err := gost.ParseSSHKeyFile(s) + if err != nil { + return nil, err + } + config.Key = key + } + if s := node.Get("ssh_authorized_keys"); s != "" { + keys, err := gost.ParseSSHAuthorizedKeysFile(s) + if err != nil { + return nil, err + } + config.AuthorizedKeys = keys + } + if node.Protocol == "forward" { + ln, err = gost.TCPListener(node.Addr) + } else { + ln, err = gost.SSHTunnelListener(node.Addr, config) + } + case "quic": + config := &gost.QUICConfig{ + TLSConfig: tlsCfg, + KeepAlive: node.GetBool("keepalive"), + Timeout: timeout, + IdleTimeout: node.GetDuration("idle"), + } + if cipher := node.Get("cipher"); cipher != "" { + sum := sha256.Sum256([]byte(cipher)) + config.Key = sum[:] + } + + ln, err = gost.QUICListener(node.Addr, config) + case "http2": + ln, err = gost.HTTP2Listener(node.Addr, tlsCfg) + case "h2": + ln, err = gost.H2Listener(node.Addr, tlsCfg, node.Get("path")) + case "h2c": + ln, err = gost.H2CListener(node.Addr, node.Get("path")) + case "tcp": + // Directly use SSH port forwarding if the last chain node is forward+ssh + if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { + chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHDirectForwardConnector() + chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() + } + ln, err = gost.TCPListener(node.Addr) + case "udp": + ln, err = gost.UDPListener(node.Addr, &gost.UDPListenConfig{ + TTL: ttl, + Backlog: node.GetInt("backlog"), + QueueSize: node.GetInt("queue"), + }) + case "rtcp": + // Directly use SSH port forwarding if the last chain node is forward+ssh + if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { + chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHRemoteForwardConnector() + chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() + } + ln, err = gost.TCPRemoteForwardListener(node.Addr, chain) + case "rudp": + ln, err = gost.UDPRemoteForwardListener(node.Addr, + chain, + &gost.UDPListenConfig{ + TTL: ttl, + Backlog: node.GetInt("backlog"), + QueueSize: node.GetInt("queue"), + }) + case "obfs4": + if err = gost.Obfs4Init(node, true); err != nil { + return nil, err + } + ln, err = gost.Obfs4Listener(node.Addr) + case "ohttp": + ln, err = gost.ObfsHTTPListener(node.Addr) + case "otls": + ln, err = gost.ObfsTLSListener(node.Addr) + case "tun": + cfg := gost.TunConfig{ + Name: node.Get("name"), + Addr: node.Get("net"), + Peer: node.Get("peer"), + MTU: node.GetInt("mtu"), + Routes: tunRoutes, + Gateway: node.Get("gw"), + } + ln, err = gost.TunListener(cfg) + case "tap": + cfg := gost.TapConfig{ + Name: node.Get("name"), + Addr: node.Get("net"), + MTU: node.GetInt("mtu"), + Routes: strings.Split(node.Get("route"), ","), + Gateway: node.Get("gw"), + } + ln, err = gost.TapListener(cfg) + case "ftcp": + ln, err = gost.FakeTCPListener( + node.Addr, + &gost.FakeTCPListenConfig{ + TTL: ttl, + Backlog: node.GetInt("backlog"), + QueueSize: node.GetInt("queue"), + }, + ) + case "dns": + ln, err = gost.DNSListener( + node.Addr, + &gost.DNSOptions{ + Mode: node.Get("mode"), + TLSConfig: tlsCfg, + }, + ) + case "redu", "redirectu": + ln, err = gost.UDPRedirectListener(node.Addr, &gost.UDPListenConfig{ + TTL: ttl, + Backlog: node.GetInt("backlog"), + QueueSize: node.GetInt("queue"), + }) + default: + ln, err = gost.TCPListener(node.Addr) + } + if err != nil { + return nil, err + } + + var handler gost.Handler + switch node.Protocol { + case "http2": + handler = gost.HTTP2Handler() + case "socks", "socks5": + handler = gost.SOCKS5Handler() + case "socks4", "socks4a": + handler = gost.SOCKS4Handler() + case "ss": + handler = gost.ShadowHandler() + case "http": + handler = gost.HTTPHandler() + case "tcp": + handler = gost.TCPDirectForwardHandler(node.Remote) + case "rtcp": + handler = gost.TCPRemoteForwardHandler(node.Remote) + case "udp": + handler = gost.UDPDirectForwardHandler(node.Remote) + case "rudp": + handler = gost.UDPRemoteForwardHandler(node.Remote) + case "forward": + handler = gost.SSHForwardHandler() + case "red", "redirect": + handler = gost.TCPRedirectHandler() + case "redu", "redirectu": + handler = gost.UDPRedirectHandler() + case "ssu": + handler = gost.ShadowUDPHandler() + case "sni": + handler = gost.SNIHandler() + case "tun": + handler = gost.TunHandler() + case "tap": + handler = gost.TapHandler() + case "dns": + handler = gost.DNSHandler(node.Remote) + case "relay": + handler = gost.RelayHandler(node.Remote) + default: + // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. + if node.Remote != "" { + handler = gost.TCPDirectForwardHandler(node.Remote) + } else { + handler = gost.AutoHandler() + } + } + + var whitelist, blacklist *gost.Permissions + if node.Values.Get("whitelist") != "" { + if whitelist, err = gost.ParsePermissions(node.Get("whitelist")); err != nil { + return nil, err + } + } + if node.Values.Get("blacklist") != "" { + if blacklist, err = gost.ParsePermissions(node.Get("blacklist")); err != nil { + return nil, err + } + } + + node.Bypass = parseBypass(node.Get("bypass")) + hosts := parseHosts(node.Get("hosts")) + ips := parseIP(node.Get("ip"), "") + + resolver := parseResolver(node.Get("dns")) + if resolver != nil { + resolver.Init( + gost.ChainResolverOption(chain), + gost.TimeoutResolverOption(timeout), + gost.TTLResolverOption(ttl), + gost.PreferResolverOption(node.Get("prefer")), + gost.SrcIPResolverOption(net.ParseIP(node.Get("ip"))), + ) + } + + handler.Init( + gost.AddrHandlerOption(ln.Addr().String()), + gost.ChainHandlerOption(chain), + gost.UsersHandlerOption(node.User), + gost.AuthenticatorHandlerOption(authenticator), + gost.TLSConfigHandlerOption(tlsCfg), + gost.WhitelistHandlerOption(whitelist), + gost.BlacklistHandlerOption(blacklist), + gost.StrategyHandlerOption(gost.NewStrategy(node.Get("strategy"))), + gost.MaxFailsHandlerOption(node.GetInt("max_fails")), + gost.FailTimeoutHandlerOption(node.GetDuration("fail_timeout")), + gost.BypassHandlerOption(node.Bypass), + gost.ResolverHandlerOption(resolver), + gost.HostsHandlerOption(hosts), + gost.RetryHandlerOption(node.GetInt("retry")), // override the global retry option. + gost.TimeoutHandlerOption(timeout), + gost.ProbeResistHandlerOption(node.Get("probe_resist")), + gost.KnockingHandlerOption(node.Get("knock")), + gost.NodeHandlerOption(node), + gost.IPsHandlerOption(ips), + gost.TCPModeHandlerOption(node.GetBool("tcp")), + gost.IPRoutesHandlerOption(tunRoutes...), + ) + + rt := router{ + node: node, + server: &gost.Server{Listener: ln}, + handler: handler, + chain: chain, + resolver: resolver, + hosts: hosts, + } + rts = append(rts, rt) + } + + return rts, nil +} + +type router struct { + node gost.Node + server *gost.Server + handler gost.Handler + chain *gost.Chain + resolver gost.Resolver + hosts *gost.Hosts +} + +func (r *router) Serve() error { + log.Logf("%s on %s", r.node.String(), r.server.Addr()) + return r.server.Serve(r.handler) +} + +func (r *router) Close() error { + if r == nil || r.server == nil { + return nil + } + return r.server.Close() +} + +func getHeaderCfg(headerFile string) map[string]string { + var h map[string]string + h = make(map[string]string) + if headerFile == "" { + return nil + } + f, err := os.Open(headerFile) + if err != nil { + return nil + } + defer f.Close() + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + index := strings.Index(line, ": ") + if index < 0 { + continue + } + key := line[0:index] + val := line[index+2:] + h[key] = val + + } + return h +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..9e030d6 --- /dev/null +++ b/common_test.go @@ -0,0 +1,260 @@ +package gost + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/go-log/log" +) + +func init() { + SetLogger(&NopLogger{}) + // SetLogger(&LogLogger{}) + Debug = true + DialTimeout = 1000 * time.Millisecond + HandshakeTimeout = 1000 * time.Millisecond + ConnectTimeout = 1000 * time.Millisecond + + cert, err := GenCertificate() + if err != nil { + panic(err) + } + DefaultTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} + +var ( + httpTestHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + data, _ := ioutil.ReadAll(r.Body) + if len(data) == 0 { + data = []byte("Hello World!") + } + io.Copy(w, bytes.NewReader(data)) + }) + + udpTestHandler = udpHandlerFunc(func(w io.Writer, r *udpRequest) { + io.Copy(w, r.Body) + }) +) + +// proxyConn obtains a connection to the proxy server. +func proxyConn(client *Client, server *Server) (net.Conn, error) { + conn, err := client.Dial(server.Addr().String()) + if err != nil { + return nil, err + } + + cc, err := client.Handshake(conn, AddrHandshakeOption(server.Addr().String())) + if err != nil { + conn.Close() + return nil, err + } + + return cc, nil +} + +// httpRoundtrip does a HTTP request-response roundtrip, and checks the data received. +func httpRoundtrip(conn net.Conn, targetURL string, data []byte) (err error) { + req, err := http.NewRequest( + http.MethodGet, + targetURL, + bytes.NewReader(data), + ) + if err != nil { + return + } + if err = req.Write(conn); err != nil { + return + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + + recv, err := ioutil.ReadAll(resp.Body) + if err != nil { + return + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + return +} + +func udpRoundtrip(logger log.Logger, client *Client, server *Server, host string, data []byte) (err error) { + conn, err := proxyConn(client, server) + if err != nil { + return + } + defer conn.Close() + + conn, err = client.Connect(conn, host) + if err != nil { + return + } + + conn.SetDeadline(time.Now().Add(1 * time.Second)) + defer conn.SetDeadline(time.Time{}) + + if _, err = conn.Write(data); err != nil { + logger.Logf("write to %s via %s: %s", host, server.Addr(), err) + return + } + + recv := make([]byte, len(data)) + if _, err = conn.Read(recv); err != nil { + logger.Logf("read from %s via %s: %s", host, server.Addr(), err) + return + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + + return +} + +func proxyRoundtrip(client *Client, server *Server, targetURL string, data []byte) (err error) { + conn, err := proxyConn(client, server) + if err != nil { + return err + } + defer conn.Close() + + u, err := url.Parse(targetURL) + if err != nil { + return + } + + conn, err = client.Connect(conn, u.Host) + if err != nil { + return + } + + conn.SetDeadline(time.Now().Add(1000 * time.Millisecond)) + defer conn.SetDeadline(time.Time{}) + + return httpRoundtrip(conn, targetURL, data) +} + +type udpRequest struct { + Body io.Reader + RemoteAddr string +} + +type udpResponseWriter struct { + conn net.PacketConn + addr net.Addr +} + +func (w *udpResponseWriter) Write(p []byte) (int, error) { + return w.conn.WriteTo(p, w.addr) +} + +type udpHandlerFunc func(w io.Writer, r *udpRequest) + +// udpTestServer is a UDP server for test. +type udpTestServer struct { + ln net.PacketConn + handler udpHandlerFunc + wg sync.WaitGroup + mu sync.Mutex // guards closed and conns + closed bool + startChan chan struct{} + exitChan chan struct{} +} + +func newUDPTestServer(handler udpHandlerFunc) *udpTestServer { + laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:0") + ln, err := net.ListenUDP("udp", laddr) + if err != nil { + panic(fmt.Sprintf("udptest: failed to listen on a port: %v", err)) + } + + return &udpTestServer{ + ln: ln, + handler: handler, + startChan: make(chan struct{}), + exitChan: make(chan struct{}), + } +} + +func (s *udpTestServer) Start() { + go s.serve() + <-s.startChan +} + +func (s *udpTestServer) serve() { + select { + case <-s.startChan: + return + default: + close(s.startChan) + } + + for { + data := make([]byte, 32*1024) + n, raddr, err := s.ln.ReadFrom(data) + if err != nil { + break + } + if s.handler != nil { + s.wg.Add(1) + go func() { + defer s.wg.Done() + w := &udpResponseWriter{ + conn: s.ln, + addr: raddr, + } + r := &udpRequest{ + Body: bytes.NewReader(data[:n]), + RemoteAddr: raddr.String(), + } + s.handler(w, r) + }() + } + } + + // signal the listener has been exited. + close(s.exitChan) +} + +func (s *udpTestServer) Addr() string { + return s.ln.LocalAddr().String() +} + +func (s *udpTestServer) Close() error { + s.mu.Lock() + + if s.closed { + s.mu.Unlock() + return nil + } + + err := s.ln.Close() + s.closed = true + s.mu.Unlock() + + <-s.exitChan + + s.wg.Wait() + + return err +} diff --git a/dns.go b/dns.go new file mode 100644 index 0000000..1b02404 --- /dev/null +++ b/dns.go @@ -0,0 +1,422 @@ +package gost + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "io" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/go-log/log" + "github.com/miekg/dns" +) + +var ( + defaultResolver Resolver +) + +func init() { + defaultResolver = NewResolver( + DefaultResolverTimeout, + NameServer{ + Addr: "127.0.0.1:53", + Protocol: "udp", + }) + defaultResolver.Init() +} + +type dnsHandler struct { + options *HandlerOptions +} + +// DNSHandler creates a Handler for DNS server. +func DNSHandler(raddr string, opts ...HandlerOption) Handler { + h := &dnsHandler{} + + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *dnsHandler) Init(opts ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range opts { + opt(h.options) + } +} + +func (h *dnsHandler) Handle(conn net.Conn) { + defer conn.Close() + + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, err := conn.Read(b) + if err != nil { + log.Logf("[dns] %s - %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + } + + mq := &dns.Msg{} + if err = mq.Unpack(b[:n]); err != nil { + log.Logf("[dns] %s - %s request unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + log.Logf("[dns] %s -> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mq)) + if Debug { + log.Logf("[dns] %s >>> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mq.String()) + } + + start := time.Now() + + resolver := h.options.Resolver + if resolver == nil { + resolver = defaultResolver + } + reply, err := resolver.Exchange(context.Background(), b[:n]) + if err != nil { + log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + rtt := time.Since(start) + + mr := &dns.Msg{} + if err = mr.Unpack(reply); err != nil { + log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + log.Logf("[dns] %s <- %s: %s [%s]", + conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mr), rtt) + if Debug { + log.Logf("[dns] %s <<< %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mr.String()) + } + + if _, err = conn.Write(reply); err != nil { + log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) + } +} + +func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { + buf := new(bytes.Buffer) + buf.WriteString(m.MsgHdr.String() + " ") + buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ") + buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ") + buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ") + buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra))) + return buf.String() +} + +// DNSOptions is options for DNS Listener. +type DNSOptions struct { + Mode string + UDPSize int + ReadTimeout time.Duration + WriteTimeout time.Duration + TLSConfig *tls.Config +} + +type dnsListener struct { + addr net.Addr + server dnsServer + connChan chan net.Conn + errc chan error +} + +// DNSListener creates a Listener for DNS proxy server. +func DNSListener(addr string, options *DNSOptions) (Listener, error) { + if options == nil { + options = &DNSOptions{} + } + + tlsConfig := options.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + + ln := &dnsListener{ + connChan: make(chan net.Conn, 128), + errc: make(chan error, 1), + } + + var srv dnsServer + var err error + switch strings.ToLower(options.Mode) { + case "tcp": + srv = &dns.Server{ + Net: "tcp", + Addr: addr, + Handler: ln, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + } + case "tls": + srv = &dns.Server{ + Net: "tcp-tls", + Addr: addr, + Handler: ln, + TLSConfig: tlsConfig, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + } + case "https": + srv = &dohServer{ + addr: addr, + tlsConfig: tlsConfig, + server: &http.Server{ + Handler: ln, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + }, + } + + default: + ln.addr, err = net.ResolveTCPAddr("tcp", addr) + srv = &dns.Server{ + Net: "udp", + Addr: addr, + Handler: ln, + UDPSize: options.UDPSize, + ReadTimeout: options.ReadTimeout, + WriteTimeout: options.WriteTimeout, + } + } + if err != nil { + return nil, err + } + + if ln.addr == nil { + ln.addr, err = net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + } + + ln.server = srv + + go func() { + if err := ln.server.ListenAndServe(); err != nil { + ln.errc <- err + return + } + }() + + select { + case err := <-ln.errc: + return nil, err + default: + } + + return ln, nil +} + +func (l *dnsListener) serve(w dnsResponseWriter, mq []byte) (err error) { + conn := newDNSServerConn(l.addr, w.RemoteAddr()) + conn.mq <- mq + + select { + case l.connChan <- conn: + default: + return errors.New("connection queue is full") + } + + select { + case mr := <-conn.mr: + _, err = w.Write(mr) + case <-conn.cclose: + err = io.EOF + } + return +} + +func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { + b, err := m.Pack() + if err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + return + } + if err := l.serve(w, b); err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + } +} + +// Based on https://github.com/semihalev/sdns +func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var buf []byte + var err error + switch r.Method { + case http.MethodGet: + buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns")) + if len(buf) == 0 || err != nil { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + case http.MethodPost: + if r.Header.Get("Content-Type") != "application/dns-message" { + http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType) + return + } + + buf, err = ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + default: + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return + } + + mq := &dns.Msg{} + if err := mq.Unpack(buf); err != nil { + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + return + } + + w.Header().Set("Server", "SDNS") + w.Header().Set("Content-Type", "application/dns-message") + + raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if err := l.serve(newDoHResponseWriter(raddr, w), buf); err != nil { + log.Logf("[dns] %s: %v", l.addr, err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } +} + +func (l *dnsListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errc: + } + return +} + +func (l *dnsListener) Close() error { + return l.server.Shutdown() +} + +func (l *dnsListener) Addr() net.Addr { + return l.addr +} + +type dnsServer interface { + ListenAndServe() error + Shutdown() error +} + +type dohServer struct { + addr string + tlsConfig *tls.Config + server *http.Server +} + +func (s *dohServer) ListenAndServe() error { + ln, err := net.Listen("tcp", s.addr) + if err != nil { + return err + } + ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, s.tlsConfig) + return s.server.Serve(ln) +} + +func (s *dohServer) Shutdown() error { + return s.server.Shutdown(context.Background()) +} + +type dnsServerConn struct { + mq chan []byte + mr chan []byte + cclose chan struct{} + laddr, raddr net.Addr +} + +func newDNSServerConn(laddr, raddr net.Addr) *dnsServerConn { + return &dnsServerConn{ + mq: make(chan []byte, 1), + mr: make(chan []byte, 1), + laddr: laddr, + raddr: raddr, + cclose: make(chan struct{}), + } +} + +func (c *dnsServerConn) Read(b []byte) (n int, err error) { + select { + case mb := <-c.mq: + n = copy(b, mb) + case <-c.cclose: + err = errors.New("connection is closed") + } + return +} + +func (c *dnsServerConn) Write(b []byte) (n int, err error) { + select { + case c.mr <- b: + n = len(b) + case <-c.cclose: + err = errors.New("broken pipe") + } + + return +} + +func (c *dnsServerConn) Close() error { + select { + case <-c.cclose: + default: + close(c.cclose) + } + return nil +} + +func (c *dnsServerConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *dnsServerConn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *dnsServerConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *dnsServerConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *dnsServerConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +type dnsResponseWriter interface { + io.Writer + RemoteAddr() net.Addr +} + +type dohResponseWriter struct { + raddr net.Addr + http.ResponseWriter +} + +func newDoHResponseWriter(raddr net.Addr, w http.ResponseWriter) dnsResponseWriter { + return &dohResponseWriter{ + raddr: raddr, + ResponseWriter: w, + } +} + +func (w *dohResponseWriter) RemoteAddr() net.Addr { + return w.raddr +} diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 0000000..a2eb077 --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,4 @@ +version: "3.4" +services: + gost: + build: . diff --git a/examples/bench/cli.go b/examples/bench/cli.go new file mode 100644 index 0000000..57c189c --- /dev/null +++ b/examples/bench/cli.go @@ -0,0 +1,220 @@ +package main + +import ( + "bufio" + "flag" + "log" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "time" + + "github.com/ginuerzh/gost" + "golang.org/x/net/http2" +) + +var ( + requests, concurrency int + quiet bool + swg, ewg sync.WaitGroup +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.IntVar(&requests, "n", 1, "Number of requests to perform") + flag.IntVar(&concurrency, "c", 1, "Number of multiple requests to make at a time") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&http2.VerboseLogs, "v", false, "HTTP2 verbose logs") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + chain := gost.NewChain( + + /* + // http+tcp + gost.Node{ + Addr: "127.0.0.1:18080", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.TCPTransporter(), + ), + }, + */ + + /* + // socks5+tcp + gost.Node{ + Addr: "127.0.0.1:11080", + Client: gost.NewClient( + gost.SOCKS5Connector(url.UserPassword("admin", "123456")), + gost.TCPTransporter(), + ), + }, + */ + + /* + // ss+tcp + gost.Node{ + Addr: "127.0.0.1:18338", + Client: gost.NewClient( + gost.ShadowConnector(url.UserPassword("chacha20", "123456")), + gost.TCPTransporter(), + ), + }, + */ + + /* + // http+ws + gost.Node{ + Addr: "127.0.0.1:18000", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.WSTransporter(nil), + ), + }, + */ + + /* + // http+wss + gost.Node{ + Addr: "127.0.0.1:18443", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.WSSTransporter(nil), + ), + }, + */ + + /* + // http+tls + gost.Node{ + Addr: "127.0.0.1:11443", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.TLSTransporter(), + ), + }, + */ + + /* + // http2 + gost.Node{ + Addr: "127.0.0.1:1443", + Client: &gost.Client{ + Connector: gost.HTTP2Connector(url.UserPassword("admin", "123456")), + Transporter: gost.HTTP2Transporter(nil), + }, + }, + */ + + /* + // http+kcp + gost.Node{ + Addr: "127.0.0.1:18388", + Client: gost.NewClient( + gost.HTTPConnector(nil), + gost.KCPTransporter(nil), + ), + }, + */ + + /* + // http+ssh + gost.Node{ + Addr: "127.0.0.1:12222", + Client: gost.NewClient( + gost.HTTPConnector(url.UserPassword("admin", "123456")), + gost.SSHTunnelTransporter(), + ), + }, + */ + + /* + // http+quic + gost.Node{ + Addr: "localhost:6121", + Client: &gost.Client{ + Connector: gost.HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: gost.QUICTransporter(nil), + }, + }, + */ + // socks5+h2 + gost.Node{ + Addr: "localhost:8443", + Client: &gost.Client{ + // Connector: gost.HTTPConnector(url.UserPassword("admin", "123456")), + Connector: gost.SOCKS5Connector(url.UserPassword("admin", "123456")), + // Transporter: gost.H2CTransporter(), // HTTP2 h2c mode + Transporter: gost.H2Transporter(nil), // HTTP2 h2 + }, + }, + ) + + total := 0 + for total < requests { + if total+concurrency > requests { + concurrency = requests - total + } + startChan := make(chan struct{}) + for i := 0; i < concurrency; i++ { + swg.Add(1) + ewg.Add(1) + go request(chain, startChan) + } + + start := time.Now() + swg.Wait() // wait for workers ready + close(startChan) // start signal + ewg.Wait() // wait for workers done + + duration := time.Since(start) + total += concurrency + log.Printf("%d/%d/%d requests done (%v/%v)", total, requests, concurrency, duration, duration/time.Duration(concurrency)) + } +} + +func request(chain *gost.Chain, start <-chan struct{}) { + defer ewg.Done() + + swg.Done() + <-start + + conn, err := chain.Dial("localhost:18888") + if err != nil { + log.Println(err) + return + } + defer conn.Close() + //conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) + req, err := http.NewRequest(http.MethodGet, "http://localhost:18888", nil) + if err != nil { + log.Println(err) + return + } + if err := req.Write(conn); err != nil { + log.Println(err) + return + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + log.Println(err) + return + } + defer resp.Body.Close() + + if gost.Debug { + rb, _ := httputil.DumpRequest(req, true) + log.Println(string(rb)) + rb, _ = httputil.DumpResponse(resp, true) + log.Println(string(rb)) + } +} diff --git a/examples/bench/srv.go b/examples/bench/srv.go new file mode 100644 index 0000000..36da8e4 --- /dev/null +++ b/examples/bench/srv.go @@ -0,0 +1,359 @@ +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "log" + "net/http" + "net/url" + "time" + + "github.com/ginuerzh/gost" + "golang.org/x/net/http2" +) + +var ( + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.BoolVar(&http2.VerboseLogs, "v", false, "HTTP2 verbose logs") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + go httpServer() + go socks5Server() + go tlsServer() + go shadowServer() + go wsServer() + go wssServer() + go kcpServer() + go tcpForwardServer() + go tcpRemoteForwardServer() + // go rudpForwardServer() + // go tcpRedirectServer() + go sshTunnelServer() + go http2Server() + go http2TunnelServer() + go quicServer() + go shadowUDPServer() + go testServer() + select {} +} + +func httpServer() { + ln, err := gost.TCPListener(":18080") + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func socks5Server() { + ln, err := gost.TCPListener(":11080") + if err != nil { + log.Fatal(err) + } + h := gost.SOCKS5Handler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + gost.TLSConfigHandlerOption(tlsConfig()), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func shadowServer() { + ln, err := gost.TCPListener(":18338") + if err != nil { + log.Fatal(err) + } + h := gost.ShadowHandler( + gost.UsersHandlerOption(url.UserPassword("chacha20", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func tlsServer() { + ln, err := gost.TLSListener(":11443", tlsConfig()) + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func wsServer() { + ln, err := gost.WSListener(":18000", nil) + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func wssServer() { + ln, err := gost.WSSListener(":18443", tlsConfig(), nil) + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func kcpServer() { + ln, err := gost.KCPListener(":18388", nil) + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler() + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func tcpForwardServer() { + ln, err := gost.TCPListener(":2222") + if err != nil { + log.Fatal(err) + } + h := gost.TCPDirectForwardHandler("localhost:22") + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func tcpRemoteForwardServer() { + ln, err := gost.TCPRemoteForwardListener( + ":1222", + /* + gost.NewChain( + gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: "localhost:12345", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector(url.UserPassword("admin", "123456")), + Transporter: gost.TCPTransporter(), + }, + }, + ), + */ + nil, + ) + if err != nil { + log.Fatal() + } + h := gost.TCPRemoteForwardHandler( + ":22", + //gost.AddrHandlerOption("127.0.0.1:22"), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func rudpForwardServer() { + ln, err := gost.UDPRemoteForwardListener( + ":10053", + gost.NewChain( + gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: "localhost:12345", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector(url.UserPassword("admin", "123456")), + Transporter: gost.TCPTransporter(), + }, + }, + ), + 30*time.Second, + ) + if err != nil { + log.Fatal() + } + h := gost.UDPRemoteForwardHandler("localhost:53") + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func tcpRedirectServer() { + ln, err := gost.TCPListener(":8008") + if err != nil { + log.Fatal(err) + } + h := gost.TCPRedirectHandler() + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func sshTunnelServer() { + ln, err := gost.SSHTunnelListener(":12222", &gost.SSHConfig{TLSConfig: tlsConfig()}) + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func http2Server() { + // http2.VerboseLogs = true + + ln, err := gost.HTTP2Listener(":1443", tlsConfig()) + if err != nil { + log.Fatal(err) + } + h := gost.HTTP2Handler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func http2TunnelServer() { + ln, err := gost.H2Listener(":8443", tlsConfig()) // HTTP2 h2 mode + // ln, err := gost.H2CListener(":8443") // HTTP2 h2c mode + if err != nil { + log.Fatal(err) + } + // h := gost.HTTPHandler( + // gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + // ) + h := gost.SOCKS5Handler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + gost.TLSConfigHandlerOption(tlsConfig()), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func quicServer() { + ln, err := gost.QUICListener("localhost:6121", &gost.QUICConfig{TLSConfig: tlsConfig()}) + if err != nil { + log.Fatal(err) + } + h := gost.HTTPHandler( + gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +func shadowUDPServer() { + ln, err := gost.ShadowUDPListener(":18338", url.UserPassword("chacha20", "123456"), 30*time.Second) + if err != nil { + log.Fatal(err) + } + h := gost.ShadowUDPdHandler( + /* + gost.ChainHandlerOption(gost.NewChain( + gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: "localhost:11080", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector(url.UserPassword("admin", "123456")), + Transporter: gost.TCPTransporter(), + }, + }, + )), + */ + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + PreferServerCipherSuites: true, + } +} + +func testServer() { + s := &http.Server{ + Addr: ":18888", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "abcdefghijklmnopqrstuvwxyz") + }), + } + log.Fatal(s.ListenAndServe()) +} diff --git a/examples/forward/direct/client.go b/examples/forward/direct/client.go new file mode 100644 index 0000000..c824b54 --- /dev/null +++ b/examples/forward/direct/client.go @@ -0,0 +1,34 @@ +package main + +import ( + "log" + + "github.com/ginuerzh/gost" +) + +func main() { + tcpForward() +} + +func tcpForward() { + chain := gost.NewChain( + gost.Node{ + Addr: "localhost:11222", + Client: &gost.Client{ + Connector: gost.SSHDirectForwardConnector(), + Transporter: gost.SSHForwardTransporter(), + }, + }, + ) + + ln, err := gost.TCPListener(":11800") + if err != nil { + log.Fatal(err) + } + h := gost.TCPDirectForwardHandler( + "localhost:22", + gost.ChainHandlerOption(chain), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} diff --git a/examples/forward/direct/server.go b/examples/forward/direct/server.go new file mode 100644 index 0000000..20aca6d --- /dev/null +++ b/examples/forward/direct/server.go @@ -0,0 +1,82 @@ +package main + +import ( + "crypto/tls" + "log" + + "github.com/ginuerzh/gost" +) + +func main() { + sshForwardServer() +} + +func sshForwardServer() { + ln, err := gost.TCPListener(":11222") + if err != nil { + log.Fatal(err) + } + h := gost.SSHForwardHandler( + gost.AddrHandlerOption(":11222"), + // gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + gost.TLSConfigHandlerOption(tlsConfig()), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/forward/remote/client.go b/examples/forward/remote/client.go new file mode 100644 index 0000000..68f1737 --- /dev/null +++ b/examples/forward/remote/client.go @@ -0,0 +1,35 @@ +package main + +import ( + "log" + + "github.com/ginuerzh/gost" +) + +func main() { + sshRemoteForward() +} + +func sshRemoteForward() { + chain := gost.NewChain( + gost.Node{ + Protocol: "forward", + Transport: "ssh", + Addr: "localhost:11222", + Client: &gost.Client{ + Connector: gost.SSHRemoteForwardConnector(), + Transporter: gost.SSHForwardTransporter(), + }, + }, + ) + + ln, err := gost.TCPRemoteForwardListener(":11800", chain) + if err != nil { + log.Fatal(err) + } + h := gost.TCPRemoteForwardHandler( + "localhost:10000", + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} diff --git a/examples/forward/remote/server.go b/examples/forward/remote/server.go new file mode 100644 index 0000000..cc83aa8 --- /dev/null +++ b/examples/forward/remote/server.go @@ -0,0 +1,82 @@ +package main + +import ( + "crypto/tls" + "log" + + "github.com/ginuerzh/gost" +) + +func main() { + sshRemoteForwardServer() +} + +func sshRemoteForwardServer() { + ln, err := gost.TCPListener(":11222") + if err != nil { + log.Fatal(err) + } + h := gost.SSHForwardHandler( + gost.AddrHandlerOption(":11222"), + // gost.UsersHandlerOption(url.UserPassword("admin", "123456")), + gost.TLSConfigHandlerOption(tlsConfig()), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/forward/udp/cli.go b/examples/forward/udp/cli.go new file mode 100644 index 0000000..ff9c823 --- /dev/null +++ b/examples/forward/udp/cli.go @@ -0,0 +1,53 @@ +package main + +import ( + "flag" + "log" + "net" + "time" +) + +var ( + concurrency int + saddr string +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&saddr, "S", ":18080", "server address") + flag.IntVar(&concurrency, "c", 1, "Number of multiple echo to make at a time") + flag.Parse() +} + +func main() { + for i := 0; i < concurrency; i++ { + go udpEchoLoop() + } + select {} +} + +func udpEchoLoop() { + addr, err := net.ResolveUDPAddr("udp", saddr) + if err != nil { + log.Fatal(err) + } + conn, err := net.DialUDP("udp", nil, addr) + if err != nil { + log.Fatal(err) + } + + msg := []byte(`abcdefghijklmnopqrstuvwxyz`) + for { + if _, err := conn.Write(msg); err != nil { + log.Fatal(err) + } + b := make([]byte, 1024) + _, err := conn.Read(b) + if err != nil { + log.Fatal(err) + } + // log.Println(string(b[:n])) + time.Sleep(100 * time.Millisecond) + } +} diff --git a/examples/forward/udp/direct.go b/examples/forward/udp/direct.go new file mode 100644 index 0000000..79f9a35 --- /dev/null +++ b/examples/forward/udp/direct.go @@ -0,0 +1,57 @@ +package main + +import ( + "flag" + "log" + "time" + + "github.com/ginuerzh/gost" +) + +var ( + laddr, faddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":18080", "listen address") + flag.StringVar(&faddr, "F", ":8080", "forward address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} +func main() { + udpDirectForwardServer() +} + +func udpDirectForwardServer() { + ln, err := gost.UDPDirectForwardListener(laddr, time.Second*30) + if err != nil { + log.Fatal(err) + } + h := gost.UDPDirectForwardHandler( + faddr, + /* + gost.ChainHandlerOption(gost.NewChain(gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: ":11080", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector( + url.UserPassword("admin", "123456"), + ), + Transporter: gost.TCPTransporter(), + }, + })), + */ + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} diff --git a/examples/forward/udp/remote.go b/examples/forward/udp/remote.go new file mode 100644 index 0000000..b0c4d50 --- /dev/null +++ b/examples/forward/udp/remote.go @@ -0,0 +1,60 @@ +package main + +import ( + "flag" + "log" + "time" + + "github.com/ginuerzh/gost" +) + +var ( + laddr, faddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":18080", "listen address") + flag.StringVar(&faddr, "F", ":8080", "forward address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} +func main() { + udpRemoteForwardServer() +} + +func udpRemoteForwardServer() { + ln, err := gost.UDPRemoteForwardListener( + laddr, + /* + gost.NewChain(gost.Node{ + Protocol: "socks5", + Transport: "tcp", + Addr: ":11080", + User: url.UserPassword("admin", "123456"), + Client: &gost.Client{ + Connector: gost.SOCKS5Connector( + url.UserPassword("admin", "123456"), + ), + Transporter: gost.TCPTransporter(), + }, + }), + */ + nil, + time.Second*30) + if err != nil { + log.Fatal(err) + } + h := gost.UDPRemoteForwardHandler( + faddr, + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} diff --git a/examples/forward/udp/srv.go b/examples/forward/udp/srv.go new file mode 100644 index 0000000..3aadf2d --- /dev/null +++ b/examples/forward/udp/srv.go @@ -0,0 +1,44 @@ +package main + +import ( + "flag" + "log" + "net" +) + +var ( + laddr string +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":8080", "listen address") + flag.Parse() +} +func main() { + udpEchoServer() +} + +func udpEchoServer() { + addr, err := net.ResolveUDPAddr("udp", laddr) + if err != nil { + log.Fatal(err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + log.Fatal(err) + } + + for { + b := make([]byte, 1024) + n, raddr, err := conn.ReadFromUDP(b) + if err != nil { + log.Fatal(err) + } + if _, err = conn.WriteToUDP(b[:n], raddr); err != nil { + log.Fatal(err) + } + + } +} diff --git a/examples/http2/http2.go b/examples/http2/http2.go new file mode 100644 index 0000000..369fdfe --- /dev/null +++ b/examples/http2/http2.go @@ -0,0 +1,125 @@ +package main + +import ( + "crypto/tls" + "flag" + "log" + "net/url" + + "golang.org/x/net/http2" + + "github.com/ginuerzh/gost" +) + +var ( + quiet bool + keyFile, certFile string + laddr string + user, passwd string +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":1443", "listen address") + flag.StringVar(&user, "u", "", "username") + flag.StringVar(&passwd, "p", "", "password") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.BoolVar(&http2.VerboseLogs, "v", false, "HTTP2 verbose log") + flag.StringVar(&keyFile, "key", "key.pem", "TLS key file") + flag.StringVar(&certFile, "cert", "cert.pem", "TLS cert file") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + http2Server() +} + +func http2Server() { + cert, er := tls.LoadX509KeyPair(certFile, keyFile) + if er != nil { + log.Println(er) + cert, er = tls.X509KeyPair(rawCert, rawKey) + if er != nil { + panic(er) + } + } + + ln, err := gost.HTTP2Listener(laddr, &tls.Config{Certificates: []tls.Certificate{cert}}) + if err != nil { + log.Fatal(err) + } + + var users []*url.Userinfo + if user != "" || passwd != "" { + users = append(users, url.UserPassword(user, passwd)) + } + + h := gost.HTTP2Handler( + gost.UsersHandlerOption(users...), + gost.AddrHandlerOption(laddr), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/quic/quicc.go b/examples/quic/quicc.go new file mode 100644 index 0000000..119e9e6 --- /dev/null +++ b/examples/quic/quicc.go @@ -0,0 +1,110 @@ +package main + +import ( + "crypto/tls" + "flag" + "log" + "time" + + "github.com/ginuerzh/gost" +) + +var ( + laddr, faddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":18080", "listen address") + flag.StringVar(&faddr, "F", "localhost:6121", "forward address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + chain := gost.NewChain( + gost.Node{ + Protocol: "socks5", + Transport: "quic", + Addr: faddr, + Client: &gost.Client{ + Connector: gost.SOCKS5Connector(nil), + Transporter: gost.QUICTransporter(&gost.QUICConfig{Timeout: 30 * time.Second, KeepAlive: true}), + }, + }, + ) + + ln, err := gost.TCPListener(laddr) + if err != nil { + log.Fatal(err) + } + h := gost.SOCKS5Handler( + gost.ChainHandlerOption(chain), + gost.TLSConfigHandlerOption(tlsConfig()), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/quic/quics.go b/examples/quic/quics.go new file mode 100644 index 0000000..246995a --- /dev/null +++ b/examples/quic/quics.go @@ -0,0 +1,100 @@ +package main + +import ( + "crypto/tls" + "flag" + "log" + + "github.com/ginuerzh/gost" +) + +var ( + laddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":6121", "listen address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + quicServer() +} + +func quicServer() { + ln, err := gost.QUICListener(laddr, &gost.QUICConfig{TLSConfig: tlsConfig()}) + if err != nil { + log.Fatal(err) + } + h := gost.SOCKS5Handler(gost.TLSConfigHandlerOption(tlsConfig())) + log.Println("server listen on", laddr) + + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/ssh/sshc.go b/examples/ssh/sshc.go new file mode 100644 index 0000000..a90a026 --- /dev/null +++ b/examples/ssh/sshc.go @@ -0,0 +1,113 @@ +package main + +import ( + "crypto/tls" + "flag" + "log" + "time" + + "github.com/ginuerzh/gost" +) + +var ( + laddr, faddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":18080", "listen address") + flag.StringVar(&faddr, "F", ":12222", "forward address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + chain := gost.NewChain( + gost.Node{ + Protocol: "socks5", + Transport: "ssh", + Addr: faddr, + HandshakeOptions: []gost.HandshakeOption{ + gost.IntervalHandshakeOption(30 * time.Second), + }, + Client: &gost.Client{ + Connector: gost.SOCKS5Connector(nil), + Transporter: gost.SSHTunnelTransporter(), + }, + }, + ) + + ln, err := gost.TCPListener(laddr) + if err != nil { + log.Fatal(err) + } + h := gost.SOCKS5Handler( + gost.ChainHandlerOption(chain), + gost.TLSConfigHandlerOption(tlsConfig()), + ) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/ssh/sshd.go b/examples/ssh/sshd.go new file mode 100644 index 0000000..72aa2f7 --- /dev/null +++ b/examples/ssh/sshd.go @@ -0,0 +1,99 @@ +package main + +import ( + "crypto/tls" + "flag" + "log" + + "github.com/ginuerzh/gost" +) + +var ( + laddr string + quiet bool +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + flag.StringVar(&laddr, "L", ":12222", "listen address") + flag.BoolVar(&quiet, "q", false, "quiet mode") + flag.BoolVar(&gost.Debug, "d", false, "debug mode") + + flag.Parse() + + if quiet { + gost.SetLogger(&gost.NopLogger{}) + } +} + +func main() { + sshTunnelServer() +} + +func sshTunnelServer() { + ln, err := gost.SSHTunnelListener(laddr, &gost.SSHConfig{TLSConfig: tlsConfig()}) + if err != nil { + log.Fatal(err) + } + h := gost.SOCKS5Handler(gost.TLSConfigHandlerOption(tlsConfig())) + log.Println("server listen on", laddr) + s := &gost.Server{ln} + log.Fatal(s.Serve(h)) +} + +var ( + rawCert = []byte(`-----BEGIN CERTIFICATE----- +MIIC+jCCAeKgAwIBAgIRAMlREhz8Miu1FQozsxbeqyMwDQYJKoZIhvcNAQELBQAw +EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xNzA1MTkwNTM5MDJaFw0xODA1MTkwNTM5 +MDJaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw +ggEKAoIBAQCyfqvv0kDriciEAVIW6JaWYFCL9a19jj1wmAGmVGxV3kNsr01kpa6N +0EBqnrcy7WknhCt1d43CqhKtTcXgJ/J9phZVxlizb8sUB85hm+MvP0N3HCg3f0Jw +hLuMrPijS6xjyw0fKCK/p6OUYMIfo5cdqeZid2WV4Ozts5uRd6Dmy2kyBe8Zg1F4 +8YJGuTWZmL2L7uZUiPY4T3q9+1iucq3vUpxymVRi1BTXnTpx+C0GS8NNgeEmevHv +482vHM5DNflAQ+mvGZvBVduq/AfirCDnt2DIZm1DcZXLrY9F3EPrlRZexmAhCDGR +LIKnMmoGicBM11Aw1fDIfJAHynk43tjPAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF +oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC +CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAAx8Lna8DcQv0bRB3L9i2+KRN +l/UhPCoFagxk1cZore4p0w+1m7OgigOoTpg5jh78DzVDhScZlgJ0bBVYp5rojeJS +cBDC9lCDcaXQfFmT5LykCAwIgw/gs+rw5Aq0y3D0m8CcqKosyZa9wnZ2cVy/+45w +emcSdboc65ueZScv38/W7aTUoVRcjyRUv0jv0zW0EPnnDlluVkeZo9spBhiTTwoj +b3zGODs6alTNIJwZIHNxxyOmfJPpVVp8BzGbMk7YBixSlZ/vbrrYV34TcSiy7J57 +lNNoVWM+OwiVk1+AEZfQDwaQfef5tsIkAZBUyITkkDKRhygtwM2110dejbEsgg== +-----END CERTIFICATE-----`) + rawKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAsn6r79JA64nIhAFSFuiWlmBQi/WtfY49cJgBplRsVd5DbK9N +ZKWujdBAap63Mu1pJ4QrdXeNwqoSrU3F4CfyfaYWVcZYs2/LFAfOYZvjLz9Ddxwo +N39CcIS7jKz4o0usY8sNHygiv6ejlGDCH6OXHanmYndlleDs7bObkXeg5stpMgXv +GYNRePGCRrk1mZi9i+7mVIj2OE96vftYrnKt71KccplUYtQU1506cfgtBkvDTYHh +Jnrx7+PNrxzOQzX5QEPprxmbwVXbqvwH4qwg57dgyGZtQ3GVy62PRdxD65UWXsZg +IQgxkSyCpzJqBonATNdQMNXwyHyQB8p5ON7YzwIDAQABAoIBAQCG4doj3Apa8z+n +IShbT1+cOyQi34A+xOIA151Hh7xmFxN0afRd/iWt3JUQ/OcLgQRZbDM7DSD+3W5H +r+G7xfQkpwFxx/T3g58+f7ehYx+GcJQWyhxJ88zNIkBnyb4KCAE5WBOOW9IGajPe +yE9pgUGMlPsXpYoKfHIOHg+NGY1pWUGBfBNR2kGrbkpZMmyy5bGa8dyrwAFBFRru +kcmmKvate8UlbRspFtd4nR/GQLTBrcDJ1k1i1Su/4BpDuDeK6LPI8ZRePGqbdcxk +TS30lsdYozuGfjZ5Zu8lSIJ//+7RjfDg8r684dpWjpalq8Quen60ZrIs01CSbfyU +k8gOzTHhAoGBAOKhp41wXveegq+WylSXFyngm4bzF4dVdTRsSbJVk7NaOx1vCU6o +/xIHoGEQyLI6wF+EaHmY89/Qu6tSV97XyBbiKeskopv5iXS/BsWTHJ1VbCA1ZLmK +HgGllEkS0xfc9AdB7b6/K7LxAAQVKP3DtN6+6pSDZh9Sv2M1j0DbhkNbAoGBAMmg +HcMfExaaeskjHqyLudtKX+znwaIoumleOGuavohR4R+Fpk8Yv8Xhb5U7Yr4gk0vY +CFmhp1WAi6QMZ/8jePlKKXl3Ney827luoKiMczp2DoYE0t0u2Kw3LfkNKfjADZ7d +JI6xPJV9/X1erwjq+4UdKqrpOf05SY4nkBMcvr6dAoGAXzisvbDJNiFTp5Mj0Abr +pJzKvBjHegVeCXi2PkfWlzUCQYu1zWcURO8PY7k5mik1SuzHONAbJ578Oy+N3AOt +/m9oTXRHHmHqbzMUFU+KZlDN7XqBp7NwiCCZ/Vn7d7tOjP4Wdl68baL07sI1RupD +xJNS3LOY5PBPmc+XMRkLgKECgYEAgBNDlJSCrZMHeAjlDTncn53I/VXiPD2e3BvL +vx6W9UT9ueZN1GSmPO6M0MDeYmOS7VSXSUhUYQ28pkJzNTC1QbWITu4YxP7anBnX +1/kPoQ0pAJzDzVharlqGy3M/PBHTFRzogfO3xkY35ZFlokaR6uayGcr42Q+w16nt +7RYPXEkCgYEA3GQYirHnGZuQ952jMvduqnpgkJiSnr0fa+94Rwa1pAhxHLFMo5s4 +fqZOtqKPj2s5X1JR0VCey1ilCcaAhWeb3tXCpbYLZSbMtjtqwA6LUeGY+Xdupsjw +cfWIcOfHsIm2kP+RCxEnZf1XwiN9wyJeiUKlE0dqmx9j7F0Bm+7YDhI= +-----END RSA PRIVATE KEY-----`) +) + +func tlsConfig() *tls.Config { + cert, err := tls.X509KeyPair(rawCert, rawKey) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{cert}} +} diff --git a/examples/ssu/ssu.go b/examples/ssu/ssu.go new file mode 100644 index 0000000..6aeee1c --- /dev/null +++ b/examples/ssu/ssu.go @@ -0,0 +1,65 @@ +package main + +import ( + "bytes" + "log" + "net" + "strconv" + + "github.com/go-gost/gosocks5" + ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" +) + +func main() { + ssuClient() +} + +func ssuClient() { + addr, err := net.ResolveUDPAddr("udp", ":18338") + if err != nil { + log.Fatal(err) + } + laddr, _ := net.ResolveUDPAddr("udp", ":10800") + conn, err := net.ListenUDP("udp", laddr) + if err != nil { + log.Fatal(err) + } + cp, err := ss.NewCipher("chacha20", "123456") + if err != nil { + log.Fatal(err) + } + cc := ss.NewSecurePacketConn(conn, cp, false) + + raddr, _ := net.ResolveUDPAddr("udp", ":8080") + msg := []byte(`abcdefghijklmnopqrstuvwxyz`) + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), msg) + buf := bytes.Buffer{} + dgram.Write(&buf) + for { + log.Printf("%# x", buf.Bytes()[3:]) + if _, err := cc.WriteTo(buf.Bytes()[3:], addr); err != nil { + log.Fatal(err) + } + b := make([]byte, 1024) + n, adr, err := cc.ReadFrom(b) + if err != nil { + log.Fatal(err) + } + log.Printf("%s: %# x", adr, b[:n]) + } +} + +func toSocksAddr(addr net.Addr) *gosocks5.Addr { + host := "0.0.0.0" + port := 0 + if addr != nil { + h, p, _ := net.SplitHostPort(addr.String()) + host = h + port, _ = strconv.Atoi(p) + } + return &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + Host: host, + Port: uint16(port), + } +} diff --git a/forward.go b/forward.go new file mode 100644 index 0000000..a419735 --- /dev/null +++ b/forward.go @@ -0,0 +1,790 @@ +package gost + +import ( + "context" + "errors" + "net" + "strings" + "sync" + "time" + + "fmt" + + "github.com/go-gost/gosocks5" + "github.com/go-log/log" + smux "github.com/xtaci/smux" +) + +type forwardConnector struct { +} + +// ForwardConnector creates a Connector for data forward client. +func ForwardConnector() Connector { + return &forwardConnector{} +} + +func (c *forwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *forwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + return conn, nil +} + +type baseForwardHandler struct { + raddr string + group *NodeGroup + options *HandlerOptions +} + +func (h *baseForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + + h.group = NewNodeGroup() // reset node group + + h.group.SetSelector(&defaultSelector{}, + WithStrategy(h.options.Strategy), + WithFilter(&FailFilter{ + MaxFails: h.options.MaxFails, + FailTimeout: h.options.FailTimeout, + }), + ) + + n := 1 + addrs := append(strings.Split(h.raddr, ","), h.options.IPs...) + for _, addr := range addrs { + if addr == "" { + continue + } + + // We treat the remote target server as a node, so we can put them in a group, + // and perform the node selection for load balancing. + h.group.AddNode(Node{ + ID: n, + Addr: addr, + Host: addr, + marker: &failMarker{}, + }) + + n++ + } +} + +type tcpDirectForwardHandler struct { + *baseForwardHandler +} + +// TCPDirectForwardHandler creates a server Handler for TCP port forwarding server. +// The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. +func TCPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpDirectForwardHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *tcpDirectForwardHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) +} + +func (h *tcpDirectForwardHandler) Handle(conn net.Conn) { + defer conn.Close() + + log.Logf("[tcp] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var cc net.Conn + var node Node + var err error + for i := 0; i < retries; i++ { + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[tcp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + } + + cc, err = h.options.Chain.Dial(node.Addr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[tcp] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + node.MarkDead() + } else { + break + } + } + if err != nil { + return + } + + node.ResetDead() + defer cc.Close() + + addr := node.Addr + if addr == "" { + addr = conn.LocalAddr().String() + } + log.Logf("[tcp] %s <-> %s", conn.RemoteAddr(), addr) + transport(conn, cc) + log.Logf("[tcp] %s >-< %s", conn.RemoteAddr(), addr) +} + +type udpDirectForwardHandler struct { + *baseForwardHandler +} + +// UDPDirectForwardHandler creates a server Handler for UDP port forwarding server. +// The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. +func UDPDirectForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &udpDirectForwardHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *udpDirectForwardHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) +} + +func (h *udpDirectForwardHandler) Handle(conn net.Conn) { + defer conn.Close() + + log.Logf("[udp] %s - %s", conn.RemoteAddr(), conn.LocalAddr()) + + var node Node + var err error + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + } + + cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr) + if err != nil { + node.MarkDead() + log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + defer cc.Close() + node.ResetDead() + + addr := node.Addr + if addr == "" { + addr = conn.LocalAddr().String() + } + log.Logf("[udp] %s <-> %s", conn.RemoteAddr(), addr) + transport(conn, cc) + log.Logf("[udp] %s >-< %s", conn.RemoteAddr(), addr) +} + +type tcpRemoteForwardHandler struct { + *baseForwardHandler +} + +// TCPRemoteForwardHandler creates a server Handler for TCP remote port forwarding server. +// The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. +func TCPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &tcpRemoteForwardHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *tcpRemoteForwardHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) +} + +func (h *tcpRemoteForwardHandler) Handle(conn net.Conn) { + defer conn.Close() + + retries := 1 + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var cc net.Conn + var node Node + var err error + for i := 0; i < retries; i++ { + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[rtcp] %s - %s : %s", conn.LocalAddr(), h.raddr, err) + return + } + } + cc, err = net.DialTimeout("tcp", node.Addr, h.options.Timeout) + if err != nil { + log.Logf("[rtcp] %s -> %s : %s", conn.LocalAddr(), node.Addr, err) + node.MarkDead() + } else { + break + } + } + if err != nil { + return + } + + defer cc.Close() + node.ResetDead() + + log.Logf("[rtcp] %s <-> %s", conn.LocalAddr(), node.Addr) + transport(cc, conn) + log.Logf("[rtcp] %s >-< %s", conn.LocalAddr(), node.Addr) +} + +type udpRemoteForwardHandler struct { + *baseForwardHandler +} + +// UDPRemoteForwardHandler creates a server Handler for UDP remote port forwarding server. +// The raddr is the remote address that the server will forward to. +// NOTE: as of 2.6, remote address can be a comma-separated address list. +func UDPRemoteForwardHandler(raddr string, opts ...HandlerOption) Handler { + h := &udpRemoteForwardHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *udpRemoteForwardHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) +} + +func (h *udpRemoteForwardHandler) Handle(conn net.Conn) { + defer conn.Close() + + var node Node + var err error + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), h.raddr, err) + return + } + } + + raddr, err := net.ResolveUDPAddr("udp", node.Addr) + if err != nil { + node.MarkDead() + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) + return + } + cc, err := net.DialUDP("udp", nil, raddr) + if err != nil { + node.MarkDead() + log.Logf("[rudp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err) + return + } + defer cc.Close() + node.ResetDead() + + log.Logf("[rudp] %s <-> %s", conn.RemoteAddr(), node.Addr) + transport(conn, cc) + log.Logf("[rudp] %s >-< %s", conn.RemoteAddr(), node.Addr) +} + +type tcpRemoteForwardListener struct { + addr net.Addr + chain *Chain + connChan chan net.Conn + ln net.Listener + session *muxSession + sessionMux sync.Mutex + closed chan struct{} + closeMux sync.Mutex +} + +// TCPRemoteForwardListener creates a Listener for TCP remote port forwarding server. +func TCPRemoteForwardListener(addr string, chain *Chain) (Listener, error) { + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + + ln := &tcpRemoteForwardListener{ + addr: laddr, + chain: chain, + connChan: make(chan net.Conn, 1024), + closed: make(chan struct{}), + } + + if !ln.isChainValid() { + ln.ln, err = net.Listen("tcp", ln.addr.String()) + return ln, err + } + + go ln.listenLoop() + + return ln, err +} + +func (l *tcpRemoteForwardListener) isChainValid() bool { + if l.chain.IsEmpty() { + return false + } + + lastNode := l.chain.LastNode() + if (lastNode.Protocol == "forward" && lastNode.Transport == "ssh") || + lastNode.Protocol == "socks5" || lastNode.Protocol == "" { + return true + } + return false +} + +func (l *tcpRemoteForwardListener) listenLoop() { + var tempDelay time.Duration + + for { + conn, err := l.accept() + + select { + case <-l.closed: + if conn != nil { + conn.Close() + } + return + default: + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("[rtcp] accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + + tempDelay = 0 + + select { + case l.connChan <- conn: + default: + conn.Close() + log.Logf("[rtcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + } +} + +func (l *tcpRemoteForwardListener) Accept() (conn net.Conn, err error) { + if l.ln != nil { + return l.ln.Accept() + } + + select { + case conn = <-l.connChan: + case <-l.closed: + err = errors.New("closed") + } + + return +} + +func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) { + lastNode := l.chain.LastNode() + if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" { + return l.chain.Dial(l.addr.String()) + } + + if l.isChainValid() { + if lastNode.GetBool("mbind") { + return l.muxAccept() // multiplexing support for binding. + } + + cc, er := l.chain.Conn() + if er != nil { + return nil, er + } + conn, err = l.waitConnectSOCKS5(cc) + if err != nil { + cc.Close() + } + } + return +} + +func (l *tcpRemoteForwardListener) muxAccept() (conn net.Conn, err error) { + session, err := l.getSession() + if err != nil { + return nil, err + } + cc, err := session.Accept() + if err != nil { + session.Close() + return nil, err + } + + return cc, nil +} + +func (l *tcpRemoteForwardListener) getSession() (s *muxSession, err error) { + l.sessionMux.Lock() + defer l.sessionMux.Unlock() + + if l.session != nil && !l.session.IsClosed() { + return l.session, nil + } + + conn, err := l.chain.Conn() + if err != nil { + return nil, err + } + + defer func(c net.Conn) { + if err != nil { + c.Close() + } + }(conn) + + conn.SetDeadline(time.Now().Add(HandshakeTimeout)) + defer conn.SetDeadline(time.Time{}) + + conn, err = socks5Handshake(conn, userSocks5HandshakeOption(l.chain.LastNode().User)) + if err != nil { + return nil, err + } + req := gosocks5.NewRequest(CmdMuxBind, toSocksAddr(l.addr)) + if err := req.Write(conn); err != nil { + log.Log("[rtcp] SOCKS5 BIND request: ", err) + return nil, err + } + + rep, err := gosocks5.ReadReply(conn) + if err != nil { + log.Log("[rtcp] SOCKS5 BIND reply: ", err) + return nil, err + } + if rep.Rep != gosocks5.Succeeded { + log.Logf("[rtcp] bind on %s failure", l.addr) + return nil, fmt.Errorf("Bind on %s failure", l.addr.String()) + } + log.Logf("[rtcp] BIND ON %s OK", rep.Addr) + + // Upgrade connection to multiplex stream. + session, err := smux.Server(conn, smux.DefaultConfig()) + if err != nil { + return nil, err + } + l.session = &muxSession{ + conn: conn, + session: session, + } + + return l.session, nil +} + +func (l *tcpRemoteForwardListener) waitConnectSOCKS5(conn net.Conn) (net.Conn, error) { + conn, err := socks5Handshake(conn, userSocks5HandshakeOption(l.chain.LastNode().User)) + if err != nil { + return nil, err + } + req := gosocks5.NewRequest(gosocks5.CmdBind, toSocksAddr(l.addr)) + if err := req.Write(conn); err != nil { + log.Log("[rtcp] SOCKS5 BIND request: ", err) + return nil, err + } + + // first reply, bind status + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + rep, err := gosocks5.ReadReply(conn) + if err != nil { + log.Log("[rtcp] SOCKS5 BIND reply: ", err) + return nil, err + } + conn.SetReadDeadline(time.Time{}) + if rep.Rep != gosocks5.Succeeded { + log.Logf("[rtcp] bind on %s failure", l.addr) + return nil, fmt.Errorf("Bind on %s failure", l.addr.String()) + } + log.Logf("[rtcp] BIND ON %s OK", rep.Addr) + + // second reply, peer connected + rep, err = gosocks5.ReadReply(conn) + if err != nil { + log.Log("[rtcp]", err) + return nil, err + } + if rep.Rep != gosocks5.Succeeded { + log.Logf("[rtcp] peer connect failure: %d", rep.Rep) + return nil, errors.New("peer connect failure") + } + + log.Logf("[rtcp] PEER %s CONNECTED", rep.Addr) + return conn, nil +} + +func (l *tcpRemoteForwardListener) Addr() net.Addr { + if l.ln != nil { + return l.ln.Addr() + } + return l.addr +} + +func (l *tcpRemoteForwardListener) Close() error { + if l.ln != nil { + return l.ln.Close() + } + + l.closeMux.Lock() + defer l.closeMux.Unlock() + + select { + case <-l.closed: + return nil + default: + close(l.closed) + } + return nil +} + +type udpRemoteForwardListener struct { + addr net.Addr + chain *Chain + connMap *udpConnMap + connChan chan net.Conn + ln *net.UDPConn + ttl time.Duration + closed chan struct{} + ready chan struct{} + once sync.Once + closeMux sync.Mutex + config *UDPListenConfig +} + +// UDPRemoteForwardListener creates a Listener for UDP remote port forwarding server. +func UDPRemoteForwardListener(addr string, chain *Chain, cfg *UDPListenConfig) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + if cfg == nil { + cfg = &UDPListenConfig{} + } + + backlog := cfg.Backlog + if backlog <= 0 { + backlog = defaultBacklog + } + + ln := &udpRemoteForwardListener{ + addr: laddr, + chain: chain, + connMap: new(udpConnMap), + connChan: make(chan net.Conn, backlog), + ready: make(chan struct{}), + closed: make(chan struct{}), + config: cfg, + } + + go ln.listenLoop() + + <-ln.ready + + return ln, err +} + +func (l *udpRemoteForwardListener) isChainValid() bool { + if l.chain.IsEmpty() { + return false + } + + lastNode := l.chain.LastNode() + return lastNode.Protocol == "socks5" || lastNode.Protocol == "" +} + +func (l *udpRemoteForwardListener) listenLoop() { + for { + conn, err := l.connect() + if err != nil { + log.Logf("[rudp] %s : %s", l.Addr(), err) + return + } + + l.once.Do(func() { + close(l.ready) + }) + + func() { + defer conn.Close() + + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := conn.ReadFrom(b) + if err != nil { + log.Logf("[rudp] %s : %s", l.Addr(), err) + break + } + + uc, ok := l.connMap.Get(raddr.String()) + if !ok { + uc = newUDPServerConn(conn, raddr, &udpServerConnConfig{ + ttl: l.config.TTL, + qsize: l.config.QueueSize, + onClose: func() { + l.connMap.Delete(raddr.String()) + log.Logf("[rudp] %s closed (%d)", raddr, l.connMap.Size()) + }, + }) + + select { + case l.connChan <- uc: + l.connMap.Set(raddr.String(), uc) + log.Logf("[rudp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) + default: + uc.Close() + log.Logf("[rudp] %s - %s: connection queue is full (%d)", + raddr, l.Addr(), cap(l.connChan)) + } + } + + select { + case uc.rChan <- b[:n]: + if Debug { + log.Logf("[rudp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + default: + log.Logf("[rudp] %s -> %s : recv queue is full", raddr, l.Addr(), cap(uc.rChan)) + } + } + }() + } +} + +func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) { + var tempDelay time.Duration + + for { + select { + case <-l.closed: + return nil, errors.New("closed") + default: + } + + if l.isChainValid() { + var cc net.Conn + cc, err = getSocks5UDPTunnel(l.chain, l.addr) + if err != nil { + log.Logf("[rudp] %s : %s", l.Addr(), err) + } else { + conn = cc.(net.PacketConn) + } + } else { + var uc *net.UDPConn + uc, err = net.ListenUDP("udp", l.addr.(*net.UDPAddr)) + if err == nil { + l.addr = uc.LocalAddr() + conn = uc + } + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("[rudp] Accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + return + } +} + +func (l *udpRemoteForwardListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case <-l.closed: + err = errors.New("accpet on closed listener") + } + return +} + +func (l *udpRemoteForwardListener) Addr() net.Addr { + return l.addr +} + +func (l *udpRemoteForwardListener) Close() error { + l.closeMux.Lock() + defer l.closeMux.Unlock() + + select { + case <-l.closed: + return nil + default: + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) + close(l.closed) + } + + return nil +} diff --git a/forward_test.go b/forward_test.go new file mode 100644 index 0000000..d47c290 --- /dev/null +++ b/forward_test.go @@ -0,0 +1,317 @@ +package gost + +import ( + "crypto/rand" + "net/http/httptest" + "net/url" + "testing" +) + +func tcpDirectForwardRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + h := TCPDirectForwardHandler(u.Host) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestTCPDirectForward(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := tcpDirectForwardRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func BenchmarkTCPDirectForward(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + b.Error(err) + } + + h := TCPDirectForwardHandler(u.Host) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkTCPDirectForwardParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + b.Error(err) + } + + h := TCPDirectForwardHandler(u.Host) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func udpDirectForwardRoundtrip(t *testing.T, host string, data []byte) error { + ln, err := UDPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + h := UDPDirectForwardHandler(host) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + + go server.Run() + defer server.Close() + + return udpRoundtrip(t, client, server, host, data) +} + +func TestUDPDirectForward(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + err := udpDirectForwardRoundtrip(t, udpSrv.Addr(), sendData) + if err != nil { + t.Error(err) + } +} + +func BenchmarkUDPDirectForward(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := UDPListener("localhost:0", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + h := UDPDirectForwardHandler(udpSrv.Addr()) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkUDPDirectForwardParallel(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := UDPListener("localhost:0", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + h := UDPDirectForwardHandler(udpSrv.Addr()) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { + b.Error(err) + } + } + }) +} + +func tcpRemoteForwardRoundtrip(t *testing.T, targetURL string, data []byte) error { + ln, err := TCPRemoteForwardListener("localhost:0", nil) // listening on localhost + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TCPTransporter(), + } + + h := TCPRemoteForwardHandler(u.Host) // forward to u.Host + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestTCPRemoteForward(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := tcpRemoteForwardRoundtrip(t, httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func udpRemoteForwardRoundtrip(t *testing.T, host string, data []byte) error { + ln, err := UDPRemoteForwardListener("localhost:0", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: UDPTransporter(), + } + + h := UDPRemoteForwardHandler(host) + h.Init() + server := &Server{ + Listener: ln, + Handler: h, + } + + go server.Run() + defer server.Close() + + return udpRoundtrip(t, client, server, host, data) +} + +func TestUDPRemoteForward(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := udpRemoteForwardRoundtrip(t, udpSrv.Addr(), sendData) + if err != nil { + t.Error(err) + } +} diff --git a/ftcp.go b/ftcp.go new file mode 100644 index 0000000..a1cfcf0 --- /dev/null +++ b/ftcp.go @@ -0,0 +1,175 @@ +package gost + +import ( + "errors" + "net" + "time" + + "github.com/go-log/log" + "github.com/xtaci/tcpraw" +) + +type fakeTCPTransporter struct{} + +// FakeTCPTransporter creates a Transporter that is used by fake tcp client. +func FakeTCPTransporter() Transporter { + return &fakeTCPTransporter{} +} + +func (tr *fakeTCPTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + raddr, er := net.ResolveTCPAddr("tcp", addr) + if er != nil { + return nil, er + } + c, err := tcpraw.Dial("tcp", addr) + if err != nil { + return + } + conn = &fakeTCPConn{ + raddr: raddr, + PacketConn: c, + } + return conn, nil +} + +func (tr *fakeTCPTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *fakeTCPTransporter) Multiplex() bool { + return false +} + +// FakeTCPListenConfig is config for fake TCP Listener. +type FakeTCPListenConfig struct { + TTL time.Duration + Backlog int + QueueSize int +} + +type fakeTCPListener struct { + ln net.PacketConn + connChan chan net.Conn + errChan chan error + connMap udpConnMap + config *FakeTCPListenConfig +} + +// FakeTCPListener creates a Listener for fake TCP server. +func FakeTCPListener(addr string, cfg *FakeTCPListenConfig) (Listener, error) { + ln, err := tcpraw.Listen("tcp", addr) + if err != nil { + return nil, err + } + + if cfg == nil { + cfg = &FakeTCPListenConfig{} + } + + backlog := cfg.Backlog + if backlog <= 0 { + backlog = defaultBacklog + } + + l := &fakeTCPListener{ + ln: ln, + connChan: make(chan net.Conn, backlog), + errChan: make(chan error, 1), + config: cfg, + } + go l.listenLoop() + return l, nil +} + +func (l *fakeTCPListener) listenLoop() { + for { + b := make([]byte, mediumBufferSize) + n, raddr, err := l.ln.ReadFrom(b) + if err != nil { + log.Logf("[ftcp] peer -> %s : %s", l.Addr(), err) + l.Close() + l.errChan <- err + close(l.errChan) + return + } + + conn, ok := l.connMap.Get(raddr.String()) + if !ok { + conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{ + ttl: l.config.TTL, + qsize: l.config.QueueSize, + onClose: func() { + l.connMap.Delete(raddr.String()) + log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size()) + }, + }) + + select { + case l.connChan <- conn: + l.connMap.Set(raddr.String(), conn) + log.Logf("[ftcp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) + default: + conn.Close() + log.Logf("[ftcp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan)) + } + } + + select { + case conn.rChan <- b[:n]: + if Debug { + log.Logf("[ftcp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + default: + log.Logf("[ftcp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan)) + } + } +} + +func (l *fakeTCPListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *fakeTCPListener) Addr() net.Addr { + return l.ln.LocalAddr() +} + +func (l *fakeTCPListener) Close() error { + err := l.ln.Close() + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) + + return err +} + +type fakeTCPConn struct { + raddr net.Addr + net.PacketConn +} + +func (c *fakeTCPConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *fakeTCPConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.raddr) +} + +func (c *fakeTCPConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7e4606a --- /dev/null +++ b/go.mod @@ -0,0 +1,60 @@ +module github.com/ginuerzh/gost + +go 1.17 + +require ( + git.torproject.org/pluggable-transports/goptlib.git v1.2.0 + github.com/LiamHaworth/go-tproxy v0.0.0-20190726054950-ef7efd7f24ed + github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d + github.com/docker/libcontainer v2.2.1+incompatible + github.com/go-gost/gosocks4 v0.0.1 + github.com/go-gost/gosocks5 v0.3.0 + github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 + github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e + github.com/go-log/log v0.2.0 + github.com/gobwas/glob v0.2.3 + github.com/gorilla/websocket v1.4.2 + github.com/klauspost/compress v1.13.6 + github.com/lucas-clemente/quic-go v0.26.0 + github.com/miekg/dns v1.1.43 + github.com/milosgajdos/tenus v0.0.3 + github.com/ryanuber/go-glob v1.0.0 + github.com/shadowsocks/go-shadowsocks2 v0.1.5 + github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 + github.com/xtaci/kcp-go v5.4.20+incompatible + github.com/xtaci/smux v1.5.16 + github.com/xtaci/tcpraw v1.2.25 + gitlab.com/yawning/obfs4.git v0.0.0-20210511220700-e330d1b7024b + golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 + golang.org/x/net v0.0.0-20220325170049-de3da57026de +) + +require ( + github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect + github.com/cheekybits/genny v1.0.0 // indirect + github.com/coreos/go-iptables v0.6.0 // indirect + github.com/dchest/siphash v1.2.2 // indirect + github.com/fsnotify/fsnotify v1.5.1 // indirect + github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect + github.com/google/gopacket v1.1.19 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/klauspost/reedsolomon v1.9.15 // indirect + github.com/marten-seemann/qtls-go1-16 v0.1.5 // indirect + github.com/marten-seemann/qtls-go1-17 v0.1.1 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.1 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 // indirect + github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161 // indirect + github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b // indirect + github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 // indirect + golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect + golang.org/x/sys v0.0.0-20220325203850-36772127a21f // indirect + golang.org/x/text v0.3.7 // indirect + golang.org/x/tools v0.1.10 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c1a85de --- /dev/null +++ b/go.sum @@ -0,0 +1,407 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.37.0/go.mod h1:TS1dMSSfndXH133OKGwekG838Om/cQT0BUHV3HcBgoo= +dmitri.shuralyov.com/app/changes v0.0.0-20180602232624-0a106ad413e3/go.mod h1:Yl+fi1br7+Rr3LqpNJf1/uxUdtRUV+Tnj0o93V2B9MU= +dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBrvjyP0v+ecvNYvCpyZgu5/xkfAUhi6wJj28eUfSU= +dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= +dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= +git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +git.torproject.org/pluggable-transports/goptlib.git v1.0.0/go.mod h1:YT4XMSkuEXbtqlydr9+OxqFAyspUv0Gr9qhM3B++o/Q= +git.torproject.org/pluggable-transports/goptlib.git v1.2.0 h1:0qRF7Dw5qXd0FtZkjWUiAh5GTutRtDGL4GXUDJ4qMHs= +git.torproject.org/pluggable-transports/goptlib.git v1.2.0/go.mod h1:4PBMl1dg7/3vMWSoWb46eGWlrxkUyn/CAJmxhDLAlDs= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/LiamHaworth/go-tproxy v0.0.0-20190726054950-ef7efd7f24ed h1:eqa6queieK8SvoszxCu0WwH7lSVeL4/N/f1JwOMw1G4= +github.com/LiamHaworth/go-tproxy v0.0.0-20190726054950-ef7efd7f24ed/go.mod h1:rA52xkgZwql9LRZXWb2arHEFP6qSR48KY2xOfWzEciQ= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= +github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= +github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g= +github.com/buger/jsonparser v0.0.0-20181115193947-bf1c66bbce23/go.mod h1:bbYlZJ7hK1yFx9hf58LP0zeX7UjIGs20ufpu3evjr+s= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= +github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/siphash v1.2.1/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= +github.com/dchest/siphash v1.2.2 h1:9DFz8tQwl9pTVt5iok/9zKyzA1Q6bRGiF3HPiEEVr9I= +github.com/dchest/siphash v1.2.2/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= +github.com/docker/libcontainer v2.2.1+incompatible h1:++SbbkCw+X8vAd4j2gOCzZ2Nn7s2xFALTf7LZKmM1/0= +github.com/docker/libcontainer v2.2.1+incompatible/go.mod h1:osvj61pYsqhNCMLGX31xr7klUBhHb/ZBuXS0o1Fvwbw= +github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= +github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= +github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= +github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= +github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= +github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= +github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJA= +github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= +github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 h1:itaaJhQJ19kUXEB4Igb0EbY8m+1Py2AaNNSBds/9gk4= +github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= +github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e h1:73NGqAs22ey3wJkIYVD/ACEoovuIuOlEzQTEoqrO5+U= +github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs= +github.com/go-log/log v0.2.0 h1:z8i91GBudxD5L3RmF0KVpetCbcGWAV7q1Tw1eRwQM9Q= +github.com/go-log/log v0.2.0/go.mod h1:xzCnwajcues/6w7lne3yK2QU7DBPW7kqbgPGG5AF65U= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= +github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/cpuid v1.2.0 h1:NMpwD2G9JSFOE1/TJjGSo5zG7Yb2bTe7eq1jH+irmeE= +github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/klauspost/cpuid/v2 v2.0.6/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/reedsolomon v1.9.15 h1:g2erWKD2M6rgnPf89fCji6jNlhMKMdXcuNHMW1SYCIo= +github.com/klauspost/reedsolomon v1.9.15/go.mod h1:eqPAcE7xar5CIzcdfwydOEdcmchAKAP/qs14y4GCBOk= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lucas-clemente/quic-go v0.26.0 h1:ALBQXr9UJ8A1LyzvceX4jd9QFsHvlI0RR6BkV16o00A= +github.com/lucas-clemente/quic-go v0.26.0/go.mod h1:AzgQoPda7N+3IqMMMkywBKggIFo2KT6pfnlrQ2QieeI= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= +github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= +github.com/marten-seemann/qtls-go1-16 v0.1.5 h1:o9JrYPPco/Nukd/HpOHMHZoBDXQqoNtUCmny98/1uqQ= +github.com/marten-seemann/qtls-go1-16 v0.1.5/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= +github.com/marten-seemann/qtls-go1-17 v0.1.1 h1:DQjHPq+aOzUeh9/lixAGunn6rIOQyWChPSI4+hgW7jc= +github.com/marten-seemann/qtls-go1-17 v0.1.1/go.mod h1:C2ekUKcDdz9SDWxec1N/MvcXBpaX9l3Nx67XaR84L5s= +github.com/marten-seemann/qtls-go1-18 v0.1.1 h1:qp7p7XXUFL7fpBvSS1sWD+uSqPvzNQK43DH+/qEkj0Y= +github.com/marten-seemann/qtls-go1-18 v0.1.1/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= +github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg= +github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4= +github.com/milosgajdos/tenus v0.0.3 h1:jmaJzwaY1DUyYVD0lM4U+uvP2kkEg1VahDqRFxIkVBE= +github.com/milosgajdos/tenus v0.0.3/go.mod h1:eIjx29vNeDOYWJuCnaHY2r4fq5egetV26ry3on7p8qY= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJEU3ofeGjhHklVoIGuVj85JJwZ6kWPaJwCIxgnFmo= +github.com/neelance/sourcemap v0.0.0-20151028013722-8c68805598ab/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= +github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= +github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3 h1:f/FNXud6gA3MNr8meMVVGxhp+QBTqY91tM8HjEuMjGg= +github.com/riobard/go-bloom v0.0.0-20200614022211-cdc8013cb5b3/go.mod h1:HgjTstvQsPGkxUsCd2KWxErBblirPizecHcpD3ffK+s= +github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/shadowsocks/go-shadowsocks2 v0.1.5 h1:PDSQv9y2S85Fl7VBeOMF9StzeXZyK1HakRm86CUbr28= +github.com/shadowsocks/go-shadowsocks2 v0.1.5/go.mod h1:AGGpIoek4HRno4xzyFiAtLHkOpcoznZEkAccaI/rplM= +github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 h1:XU9hik0exChEmY92ALW4l9WnDodxLVS9yOSNh2SizaQ= +github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601/go.mod h1:mttDPaeLm87u74HMrP+n2tugXvIKWcwff/cqSX0lehY= +github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= +github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM= +github.com/shurcooL/github_flavored_markdown v0.0.0-20181002035957-2122de532470/go.mod h1:2dOwnU2uBioM+SGy2aZoq1f/Sd1l9OkAeAUvjSyvgU0= +github.com/shurcooL/go v0.0.0-20180423040247-9e1955d9fb6e/go.mod h1:TDJrrUr11Vxrven61rcy3hJMUqaf/CLWYhHNPmT14Lk= +github.com/shurcooL/go-goon v0.0.0-20170922171312-37c2f522c041/go.mod h1:N5mDOmsrJOB+vfqUK+7DmDyjhSLIIBnXo9lvZJj3MWQ= +github.com/shurcooL/gofontwoff v0.0.0-20180329035133-29b52fc0a18d/go.mod h1:05UtEgK5zq39gLST6uB0cf3NEHjETfB4Fgr3Gx5R9Vw= +github.com/shurcooL/gopherjslib v0.0.0-20160914041154-feb6d3990c2c/go.mod h1:8d3azKNyqcHP1GaQE/c6dDgjkgSx2BZ4IoEi4F1reUI= +github.com/shurcooL/highlight_diff v0.0.0-20170515013008-09bb4053de1b/go.mod h1:ZpfEhSmds4ytuByIcDnOLkTHGUI6KNqRNPDLHDk+mUU= +github.com/shurcooL/highlight_go v0.0.0-20181028180052-98c3abbbae20/go.mod h1:UDKB5a1T23gOMUJrI+uSuH0VRDStOiUVSjBTRDVBVag= +github.com/shurcooL/home v0.0.0-20181020052607-80b7ffcb30f9/go.mod h1:+rgNQw2P9ARFAs37qieuu7ohDNQ3gds9msbT2yn85sg= +github.com/shurcooL/htmlg v0.0.0-20170918183704-d01228ac9e50/go.mod h1:zPn1wHpTIePGnXSHpsVPWEktKXHr6+SS6x/IKRb7cpw= +github.com/shurcooL/httperror v0.0.0-20170206035902-86b7830d14cc/go.mod h1:aYMfkZ6DWSJPJ6c4Wwz3QtW22G7mf/PEgaB9k/ik5+Y= +github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371/go.mod h1:ZY1cvUeJuFPAdZ/B6v7RHavJWZn2YPVFQ1OSXhCGOkg= +github.com/shurcooL/httpgzip v0.0.0-20180522190206-b1c53ac65af9/go.mod h1:919LwcH0M7/W4fcZ0/jy0qGght1GIhqyS/EgWGH2j5Q= +github.com/shurcooL/issues v0.0.0-20181008053335-6292fdc1e191/go.mod h1:e2qWDig5bLteJ4fwvDAc2NHzqFEthkqn7aOZAOpj+PQ= +github.com/shurcooL/issuesapp v0.0.0-20180602232740-048589ce2241/go.mod h1:NPpHK2TI7iSaM0buivtFUc9offApnI0Alt/K8hcHy0I= +github.com/shurcooL/notifications v0.0.0-20181007000457-627ab5aea122/go.mod h1:b5uSkrEVM1jQUspwbixRBhaIjIzL2xazXp6kntxYle0= +github.com/shurcooL/octicon v0.0.0-20181028054416-fa4f57f9efb2/go.mod h1:eWdoE5JD4R5UVWDucdOPg1g2fqQRq78IQa9zlOV1vpQ= +github.com/shurcooL/reactions v0.0.0-20181006231557-f2e0b4ca5b82/go.mod h1:TCR1lToEk4d2s07G3XGfz2QrgHXg4RJBvjrOozvoWfk= +github.com/shurcooL/sanitized_anchor_name v0.0.0-20170918181015-86672fcb3f95/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/shurcooL/users v0.0.0-20180125191416-49c67e49c537/go.mod h1:QJTqeLYEDaXHZDBsXlPCDqdhQuJkuw4NOtaxYe3xii4= +github.com/shurcooL/webdavfs v0.0.0-20170829043945-18c3829fa133/go.mod h1:hKmq5kWdCj2z2KEozexVbfEZIWiTjhE0+UjmZgPqehw= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= +github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161 h1:89CEmDvlq/F7SJEOqkIdNDGJXrQIhuIx9D2DBXjavSU= +github.com/templexxx/cpufeat v0.0.0-20180724012125-cef66df7f161/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU= +github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b h1:fj5tQ8acgNUr6O8LEplsxDhUIe2573iLkJc+PqnzZTI= +github.com/templexxx/xor v0.0.0-20191217153810-f85b25db303b/go.mod h1:5XA7W9S6mni3h5uvOC75dA3m9CCCaS83lltmc0ukdi4= +github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= +github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= +github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= +github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= +github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/xtaci/kcp-go v5.4.20+incompatible h1:TN1uey3Raw0sTz0Fg8GkfM0uH3YwzhnZWQ1bABv5xAg= +github.com/xtaci/kcp-go v5.4.20+incompatible/go.mod h1:bN6vIwHQbfHaHtFpEssmWsN45a+AZwO7eyRCmEIbtvE= +github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37 h1:EWU6Pktpas0n8lLQwDsRyZfmkPeRbdgPtW609es+/9E= +github.com/xtaci/lossyconn v0.0.0-20200209145036-adba10fffc37/go.mod h1:HpMP7DB2CyokmAh4lp0EQnnWhmycP/TvwBGzvuie+H0= +github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk= +github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= +github.com/xtaci/tcpraw v1.2.25 h1:VDlqo0op17JeXBM6e2G9ocCNLOJcw9mZbobMbJjo0vk= +github.com/xtaci/tcpraw v1.2.25/go.mod h1:dKyZ2V75s0cZ7cbgJYdxPvms7af0joIeOyx1GgJQbLk= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +gitlab.com/yawning/bsaes.git v0.0.0-20190805113838-0a714cd429ec/go.mod h1:BZ1RAoRPbCxum9Grlv5aeksu2H8BiKehBYooU2LFiOQ= +gitlab.com/yawning/obfs4.git v0.0.0-20210511220700-e330d1b7024b h1:w/f20IHUkUYEp+xYgpKz4Bs78zms0DbjPZCep5lc0xA= +gitlab.com/yawning/obfs4.git v0.0.0-20210511220700-e330d1b7024b/go.mod h1:OM1ngEp5brdANPox+rqk2AGTLQvzobyB5Dwm3vu3CgM= +gitlab.com/yawning/utls.git v0.0.12-1/go.mod h1:3ONKiSFR9Im/c3t5RKmMJTVdmZN496FNyk3mjrY1dyo= +go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= +go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= +golang.org/x/build v0.0.0-20190111050920-041ab4dc3f9d/go.mod h1:OWs+y06UdEOHN4y+MfF/py+xQ/tYqIWW03b70/CG9Rw= +golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 h1:S25/rfnfsMVgORT4/J61MJ7rdyseOZOyvLIrZEZ7s6s= +golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= +golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190313220215-9f648a60d977/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190328230028-74de082e2cca/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220325170049-de3da57026de h1:pZB1TWnKi+o4bENlbzAgLrEbY4RMYmUIRobMcSmfeYc= +golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190316082340-a2f829d7f35f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190804053845-51ab0e2deafa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220325203850-36772127a21f h1:TrmogKRsSOxRMJbLYGrB4SBbW+LJcEllYBLME5Zk5pU= +golang.org/x/sys v0.0.0-20220325203850-36772127a21f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0= +google.golang.org/api v0.1.0/go.mod h1:UGEZY7KEX120AnNLIHFMKIo4obdJhkp2tPbaPlQx13Y= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20181202183823-bd91e49a0898/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg= +google.golang.org/genproto v0.0.0-20190306203927-b5d61aea6440/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= +google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= +google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= +honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sourcegraph.com/sourcegraph/go-diff v0.5.0/go.mod h1:kuch7UrkMzY0X+p9CRK03kfuPQ2zzQcaEFbx8wA8rck= +sourcegraph.com/sqs/pbtypes v0.0.0-20180604144634-d3ebe8f20ae4/go.mod h1:ketZ/q3QxT9HOBeFhu6RdvsftgpsbFHBF5Cas6cDKZ0= diff --git a/gost.go b/gost.go new file mode 100644 index 0000000..8d39945 --- /dev/null +++ b/gost.go @@ -0,0 +1,216 @@ +package gost + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "io" + "math/big" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/go-log/log" +) + +// Version is the gost version. +const Version = "2.11.2-EvanMod-v1.1" + +// Debug is a flag that enables the debug log. +var Debug bool + +var ( + tinyBufferSize = 512 + smallBufferSize = 2 * 1024 // 2KB small buffer + mediumBufferSize = 8 * 1024 // 8KB medium buffer + largeBufferSize = 32 * 1024 // 32KB large buffer +) + +var ( + sPool = sync.Pool{ + New: func() interface{} { + return make([]byte, smallBufferSize) + }, + } + mPool = sync.Pool{ + New: func() interface{} { + return make([]byte, mediumBufferSize) + }, + } + lPool = sync.Pool{ + New: func() interface{} { + return make([]byte, largeBufferSize) + }, + } +) + +var ( + // KeepAliveTime is the keep alive time period for TCP connection. + KeepAliveTime = 180 * time.Second + // DialTimeout is the timeout of dial. + DialTimeout = 5 * time.Second + // HandshakeTimeout is the timeout of handshake. + HandshakeTimeout = 5 * time.Second + // ConnectTimeout is the timeout for connect. + ConnectTimeout = 5 * time.Second + // ReadTimeout is the timeout for reading. + ReadTimeout = 10 * time.Second + // WriteTimeout is the timeout for writing. + WriteTimeout = 10 * time.Second + // PingTimeout is the timeout for pinging. + PingTimeout = 30 * time.Second + // PingRetries is the reties of ping. + PingRetries = 1 + // default udp node TTL in second for udp port forwarding. + defaultTTL = 60 * time.Second + defaultBacklog = 128 + defaultQueueSize = 128 +) + +var ( + // DefaultTLSConfig is a default TLS config for internal use. + DefaultTLSConfig *tls.Config + + // DefaultUserAgent is the default HTTP User-Agent header used by HTTP and websocket. + DefaultUserAgent = "Chrome/102.0.0.0" + + // DefaultMTU is the default mtu for tun/tap device + DefaultMTU = 1350 +) + +// SetLogger sets a new logger for internal log system. +func SetLogger(logger log.Logger) { + log.DefaultLogger = logger +} + +// GenCertificate generates a random TLS certificate. +func GenCertificate() (cert tls.Certificate, err error) { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + return + } + return tls.X509KeyPair(rawCert, rawKey) +} + +func generateKeyPair() (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"gost"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} + +type readWriter struct { + r io.Reader + w io.Writer +} + +func (rw *readWriter) Read(p []byte) (n int, err error) { + return rw.r.Read(p) +} + +func (rw *readWriter) Write(p []byte) (n int, err error) { + return rw.w.Write(p) +} + +var ( + nopClientConn = &nopConn{} +) + +// a nop connection implements net.Conn, +// it does nothing. +type nopConn struct{} + +func (c *nopConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "nop", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *nopConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "nop", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *nopConn) Close() error { + return nil +} + +func (c *nopConn) LocalAddr() net.Addr { + return nil +} + +func (c *nopConn) RemoteAddr() net.Addr { + return nil +} + +func (c *nopConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *nopConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *nopConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +// splitLine splits a line text by white space, mainly used by config parser. +func splitLine(line string) []string { + if line == "" { + return nil + } + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss +} + +func connStateCallback(conn net.Conn, cs http.ConnState) { + switch cs { + case http.StateNew: + conn.SetReadDeadline(time.Now().Add(30 * time.Second)) + default: + } +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..eee515c --- /dev/null +++ b/handler.go @@ -0,0 +1,270 @@ +package gost + +import ( + "bufio" + "crypto/tls" + "net" + "net/url" + "time" + + "github.com/go-gost/gosocks4" + "github.com/go-gost/gosocks5" + "github.com/go-log/log" +) + +// Handler is a proxy server handler +type Handler interface { + Init(options ...HandlerOption) + Handle(net.Conn) +} + +// HandlerOptions describes the options for Handler. +type HandlerOptions struct { + Addr string + Chain *Chain + Users []*url.Userinfo + Authenticator Authenticator + TLSConfig *tls.Config + Whitelist *Permissions + Blacklist *Permissions + Strategy Strategy + MaxFails int + FailTimeout time.Duration + Bypass *Bypass + Retries int + Timeout time.Duration + Resolver Resolver + Hosts *Hosts + ProbeResist string + KnockingHost string + Node Node + Host string + IPs []string + TCPMode bool + IPRoutes []IPRoute +} + +// HandlerOption allows a common way to set handler options. +type HandlerOption func(opts *HandlerOptions) + +// AddrHandlerOption sets the Addr option of HandlerOptions. +func AddrHandlerOption(addr string) HandlerOption { + return func(opts *HandlerOptions) { + opts.Addr = addr + } +} + +// ChainHandlerOption sets the Chain option of HandlerOptions. +func ChainHandlerOption(chain *Chain) HandlerOption { + return func(opts *HandlerOptions) { + opts.Chain = chain + } +} + +// UsersHandlerOption sets the Users option of HandlerOptions. +func UsersHandlerOption(users ...*url.Userinfo) HandlerOption { + return func(opts *HandlerOptions) { + opts.Users = users + + kvs := make(map[string]string) + for _, u := range users { + if u != nil { + kvs[u.Username()], _ = u.Password() + } + } + if len(kvs) > 0 { + opts.Authenticator = NewLocalAuthenticator(kvs) + } + } +} + +// AuthenticatorHandlerOption sets the Authenticator option of HandlerOptions. +func AuthenticatorHandlerOption(au Authenticator) HandlerOption { + return func(opts *HandlerOptions) { + opts.Authenticator = au + } +} + +// TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions. +func TLSConfigHandlerOption(config *tls.Config) HandlerOption { + return func(opts *HandlerOptions) { + opts.TLSConfig = config + } +} + +// WhitelistHandlerOption sets the Whitelist option of HandlerOptions. +func WhitelistHandlerOption(whitelist *Permissions) HandlerOption { + return func(opts *HandlerOptions) { + opts.Whitelist = whitelist + } +} + +// BlacklistHandlerOption sets the Blacklist option of HandlerOptions. +func BlacklistHandlerOption(blacklist *Permissions) HandlerOption { + return func(opts *HandlerOptions) { + opts.Blacklist = blacklist + } +} + +// BypassHandlerOption sets the bypass option of HandlerOptions. +func BypassHandlerOption(bypass *Bypass) HandlerOption { + return func(opts *HandlerOptions) { + opts.Bypass = bypass + } +} + +// StrategyHandlerOption sets the strategy option of HandlerOptions. +func StrategyHandlerOption(strategy Strategy) HandlerOption { + return func(opts *HandlerOptions) { + opts.Strategy = strategy + } +} + +// MaxFailsHandlerOption sets the max_fails option of HandlerOptions. +func MaxFailsHandlerOption(n int) HandlerOption { + return func(opts *HandlerOptions) { + opts.MaxFails = n + } +} + +// FailTimeoutHandlerOption sets the fail_timeout option of HandlerOptions. +func FailTimeoutHandlerOption(d time.Duration) HandlerOption { + return func(opts *HandlerOptions) { + opts.FailTimeout = d + } +} + +// RetryHandlerOption sets the retry option of HandlerOptions. +func RetryHandlerOption(retries int) HandlerOption { + return func(opts *HandlerOptions) { + opts.Retries = retries + } +} + +// TimeoutHandlerOption sets the timeout option of HandlerOptions. +func TimeoutHandlerOption(timeout time.Duration) HandlerOption { + return func(opts *HandlerOptions) { + opts.Timeout = timeout + } +} + +// ResolverHandlerOption sets the resolver option of HandlerOptions. +func ResolverHandlerOption(resolver Resolver) HandlerOption { + return func(opts *HandlerOptions) { + opts.Resolver = resolver + } +} + +// HostsHandlerOption sets the Hosts option of HandlerOptions. +func HostsHandlerOption(hosts *Hosts) HandlerOption { + return func(opts *HandlerOptions) { + opts.Hosts = hosts + } +} + +// ProbeResistHandlerOption adds the probe resistance for HTTP proxy. +func ProbeResistHandlerOption(pr string) HandlerOption { + return func(opts *HandlerOptions) { + opts.ProbeResist = pr + } +} + +// KnockingHandlerOption adds the knocking host for probe resistance. +func KnockingHandlerOption(host string) HandlerOption { + return func(opts *HandlerOptions) { + opts.KnockingHost = host + } +} + +// NodeHandlerOption set the server node for server handler. +func NodeHandlerOption(node Node) HandlerOption { + return func(opts *HandlerOptions) { + opts.Node = node + } +} + +// HostHandlerOption sets the target host for SNI proxy. +func HostHandlerOption(host string) HandlerOption { + return func(opts *HandlerOptions) { + opts.Host = host + } +} + +// IPsHandlerOption sets the ip list for port forward. +func IPsHandlerOption(ips []string) HandlerOption { + return func(opts *HandlerOptions) { + opts.IPs = ips + } +} + +// TCPModeHandlerOption sets the tcp mode for tun/tap device. +func TCPModeHandlerOption(b bool) HandlerOption { + return func(opts *HandlerOptions) { + opts.TCPMode = b + } +} + +// IPRoutesHandlerOption sets the IP routes for tun tunnel. +func IPRoutesHandlerOption(routes ...IPRoute) HandlerOption { + return func(opts *HandlerOptions) { + opts.IPRoutes = routes + } +} + +type autoHandler struct { + options *HandlerOptions +} + +// AutoHandler creates a server Handler for auto proxy server. +func AutoHandler(opts ...HandlerOption) Handler { + h := &autoHandler{} + h.Init(opts...) + return h +} + +func (h *autoHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + +func (h *autoHandler) Handle(conn net.Conn) { + br := bufio.NewReader(conn) + b, err := br.Peek(1) + if err != nil { + log.Logf("[auto] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) + conn.Close() + return + } + + cc := &bufferdConn{Conn: conn, br: br} + var handler Handler + switch b[0] { + case gosocks4.Ver4: + // SOCKS4(a) does not suppport authentication method, + // so we ignore it when credentials are specified for security reason. + if len(h.options.Users) > 0 { + cc.Close() + return + } + handler = &socks4Handler{options: h.options} + case gosocks5.Ver5: // socks5 + handler = &socks5Handler{options: h.options} + default: // http + handler = &httpHandler{options: h.options} + } + handler.Init() + handler.Handle(cc) +} + +type bufferdConn struct { + net.Conn + br *bufio.Reader +} + +func (c *bufferdConn) Read(b []byte) (int, error) { + return c.br.Read(b) +} diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..d412d4c --- /dev/null +++ b/handler_test.go @@ -0,0 +1,228 @@ +package gost + +import ( + "crypto/rand" + "crypto/tls" + "net/http/httptest" + "net/url" + "testing" +) + +func autoHTTPProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: TCPTransporter(), + } + server := &Server{ + Listener: ln, + Handler: AutoHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoHTTPProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := autoHTTPProxyRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func autoSocks5ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: AutoHandler(UsersHandlerOption(serverInfo...)), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSOCKS5Proxy(t *testing.T) { + cert, err := GenCertificate() + if err != nil { + panic(err) + } + DefaultTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := autoSocks5ProxyRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func autoSOCKS4ProxyRoundtrip(targetURL string, data []byte, options ...HandlerOption) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: AutoHandler(options...), + } + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSOCKS4Proxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := autoSOCKS4ProxyRoundtrip(httpSrv.URL, sendData); err != nil { + t.Errorf("got error: %v", err) + } + + if err := autoSOCKS4ProxyRoundtrip(httpSrv.URL, sendData, + UsersHandlerOption(url.UserPassword("admin", "123456"))); err == nil { + t.Errorf("authentication required auto handler for SOCKS4 should failed") + } +} + +func autoSocks4aProxyRoundtrip(targetURL string, data []byte, options ...HandlerOption) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: AutoHandler(options...), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSOCKS4AProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := autoSocks4aProxyRoundtrip(httpSrv.URL, sendData); err != nil { + t.Errorf("got error: %v", err) + } + + if err := autoSocks4aProxyRoundtrip(httpSrv.URL, sendData, + UsersHandlerOption(url.UserPassword("admin", "123456"))); err == nil { + t.Errorf("authentication required auto handler for SOCKS4A should failed") + } +} + +func autoSSProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo *url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: AutoHandler(UsersHandlerOption(serverInfo)), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestAutoSSProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssTests { + err := autoSSProxyRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + t.Errorf("#%d should failed", i) + } + } +} diff --git a/hosts.go b/hosts.go new file mode 100644 index 0000000..6df0325 --- /dev/null +++ b/hosts.go @@ -0,0 +1,160 @@ +package gost + +import ( + "bufio" + "io" + "net" + "sync" + "time" + + "github.com/go-log/log" +) + +// Host is a static mapping from hostname to IP. +type Host struct { + IP net.IP + Hostname string + Aliases []string +} + +// NewHost creates a Host. +func NewHost(ip net.IP, hostname string, aliases ...string) Host { + return Host{ + IP: ip, + Hostname: hostname, + Aliases: aliases, + } +} + +// Hosts is a static table lookup for hostnames. +// For each host a single line should be present with the following information: +// IP_address canonical_hostname [aliases...] +// Fields of the entry are separated by any number of blanks and/or tab characters. +// Text from a "#" character until the end of the line is a comment, and is ignored. +type Hosts struct { + hosts []Host + period time.Duration + stopped chan struct{} + mux sync.RWMutex +} + +// NewHosts creates a Hosts with optional list of hosts. +func NewHosts(hosts ...Host) *Hosts { + return &Hosts{ + hosts: hosts, + stopped: make(chan struct{}), + } +} + +// AddHost adds host(s) to the host table. +func (h *Hosts) AddHost(host ...Host) { + h.mux.Lock() + defer h.mux.Unlock() + + h.hosts = append(h.hosts, host...) +} + +// Lookup searches the IP address corresponds to the given host from the host table. +func (h *Hosts) Lookup(host string) (ip net.IP) { + if h == nil || host == "" { + return + } + + h.mux.RLock() + defer h.mux.RUnlock() + + for _, h := range h.hosts { + if h.Hostname == host { + ip = h.IP + break + } + for _, alias := range h.Aliases { + if alias == host { + ip = h.IP + break + } + } + } + if ip != nil && Debug { + log.Logf("[hosts] hit: %s %s", host, ip.String()) + } + return +} + +// Reload parses config from r, then live reloads the hosts. +func (h *Hosts) Reload(r io.Reader) error { + var period time.Duration + var hosts []Host + + if r == nil || h.Stopped() { + return nil + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + ss := splitLine(line) + if len(ss) < 2 { + continue // invalid lines are ignored + } + + switch ss[0] { + case "reload": // reload option + period, _ = time.ParseDuration(ss[1]) + default: + ip := net.ParseIP(ss[0]) + if ip == nil { + break // invalid IP addresses are ignored + } + host := Host{ + IP: ip, + Hostname: ss[1], + } + if len(ss) > 2 { + host.Aliases = ss[2:] + } + hosts = append(hosts, host) + } + } + if err := scanner.Err(); err != nil { + return err + } + + h.mux.Lock() + h.period = period + h.hosts = hosts + h.mux.Unlock() + + return nil +} + +// Period returns the reload period +func (h *Hosts) Period() time.Duration { + if h.Stopped() { + return -1 + } + + h.mux.RLock() + defer h.mux.RUnlock() + + return h.period +} + +// Stop stops reloading. +func (h *Hosts) Stop() { + select { + case <-h.stopped: + default: + close(h.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (h *Hosts) Stopped() bool { + select { + case <-h.stopped: + return true + default: + return false + } +} diff --git a/hosts_test.go b/hosts_test.go new file mode 100644 index 0000000..2fbae32 --- /dev/null +++ b/hosts_test.go @@ -0,0 +1,130 @@ +package gost + +import ( + "bytes" + "io" + "net" + "testing" + "time" +) + +var hostsLookupTests = []struct { + hosts []Host + host string + ip net.IP +}{ + {nil, "", nil}, + {nil, "example.com", nil}, + {[]Host{}, "", nil}, + {[]Host{}, "example.com", nil}, + {[]Host{NewHost(nil, "")}, "", nil}, + {[]Host{NewHost(nil, "example.com")}, "example.com", nil}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "")}, "", nil}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example.com", net.IPv4(192, 168, 1, 1)}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example", nil}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "example", net.IPv4(192, 168, 1, 1)}, + {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "examples", net.IPv4(192, 168, 1, 1)}, +} + +func TestHostsLookup(t *testing.T) { + for i, tc := range hostsLookupTests { + hosts := NewHosts() + hosts.AddHost(tc.hosts...) + ip := hosts.Lookup(tc.host) + if !ip.Equal(tc.ip) { + t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) + } + } +} + +var HostsReloadTests = []struct { + r io.Reader + period time.Duration + host string + ip net.IP + stopped bool +}{ + { + r: nil, + period: 0, + host: "", + ip: nil, + }, + { + r: bytes.NewBufferString(""), + period: 0, + host: "example.com", + ip: nil, + }, + { + r: bytes.NewBufferString("reload 10s"), + period: 10 * time.Second, + host: "example.com", + ip: nil, + }, + { + r: bytes.NewBufferString("#reload 10s\ninvalid.ip.addr example.com"), + period: 0, + ip: nil, + }, + { + r: bytes.NewBufferString("reload 10s\n192.168.1.1"), + period: 10 * time.Second, + host: "", + ip: nil, + }, + { + r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com"), + period: 0, + host: "example.com", + ip: net.IPv4(192, 168, 1, 1), + }, + { + r: bytes.NewBufferString("#reload 10s\n#192.168.1.1 example.com"), + period: 0, + host: "example.com", + ip: nil, + stopped: true, + }, + { + r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), + period: 0, + host: "example", + ip: net.IPv4(192, 168, 1, 1), + stopped: true, + }, + { + r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), + period: 0, + host: "examples", + ip: net.IPv4(192, 168, 1, 1), + stopped: true, + }, +} + +func TestHostsReload(t *testing.T) { + for i, tc := range HostsReloadTests { + hosts := NewHosts() + if err := hosts.Reload(tc.r); err != nil { + t.Error(err) + } + if hosts.Period() != tc.period { + t.Errorf("#%d test failed: period value should be %v, got %v", + i, tc.period, hosts.Period()) + } + ip := hosts.Lookup(tc.host) + if !ip.Equal(tc.ip) { + t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) + } + if tc.stopped { + hosts.Stop() + if hosts.Period() >= 0 { + t.Errorf("period of the stopped reloader should be minus value") + } + } + if hosts.Stopped() != tc.stopped { + t.Errorf("#%d test failed: stopped value should be %v, got %v", + i, tc.stopped, hosts.Stopped()) + } + } +} diff --git a/http.go b/http.go new file mode 100644 index 0000000..d001793 --- /dev/null +++ b/http.go @@ -0,0 +1,483 @@ +package gost + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/go-log/log" +) + +type httpConnector struct { + User *url.Userinfo +} + +// HTTPConnector creates a Connector for HTTP proxy client. +// It accepts an optional auth info for HTTP Basic Authentication. +func HTTPConnector(user *url.Userinfo) Connector { + return &httpConnector{User: user} +} + +func (c *httpConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *httpConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + ua := opts.UserAgent + if ua == "" { + ua = DefaultUserAgent + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: address}, + Host: address, + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + } + req.Header.Set("User-Agent", ua) + req.Header.Set("Proxy-Connection", "keep-alive") + + user := opts.User + if user == nil { + user = c.User + } + + if user != nil { + u := user.Username() + p, _ := user.Password() + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) + } + + //Process Header + for k, v := range opts.HeaderConfig { + if len(k) > 2 && k[0:2] == "--" { + req.Header.Del(k[2:]) + continue + } + req.Header.Set(k, v) + } + + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Log(string(dump)) + } + + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return nil, err + } + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Log(string(dump)) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s", resp.Status) + } + + return conn, nil +} + +type httpHandler struct { + options *HandlerOptions +} + +// HTTPHandler creates a server Handler for HTTP proxy server. +func HTTPHandler(opts ...HandlerOption) Handler { + h := &httpHandler{} + h.Init(opts...) + return h +} + +func (h *httpHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + +func (h *httpHandler) Handle(conn net.Conn) { + defer conn.Close() + + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + log.Logf("[http] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + defer req.Body.Close() + + h.handleRequest(conn, req) +} + +func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { + if req == nil { + return + } + + // try to get the actual host. + if v := req.Header.Get("Gost-Target"); v != "" { + if h, err := decodeServerName(v); err == nil { + req.Host = h + } + } + + host := req.Host + if _, port, _ := net.SplitHostPort(host); port == "" { + host = net.JoinHostPort(host, "80") + } + + u, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) + if u != "" { + u += "@" + } + log.Logf("[http] %s%s -> %s -> %s", + u, conn.RemoteAddr(), h.options.Node.String(), host) + + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + req.Header.Del("Gost-Target") + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + } + resp.Header.Add("Proxy-Agent", "gost/"+Version) + + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http] %s - %s : Unauthorized to tcp connect to %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + resp.StatusCode = http.StatusForbidden + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } + + if h.options.Bypass.Contains(host) { + resp.StatusCode = http.StatusForbidden + + log.Logf("[http] %s - %s bypass %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } + + if !h.authenticate(conn, req, resp) { + return + } + + if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { + resp.StatusCode = http.StatusBadRequest + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } + + req.Header.Del("Proxy-Authorization") + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var err error + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(host) + if err != nil { + log.Logf("[http] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + continue + } + + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + conn.RemoteAddr(), h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", host) + log.Log("[route]", buf.String()) + + // forward http request + lastNode := route.LastNode() + if req.Method != http.MethodConnect && lastNode.Protocol == "http" { + err = h.forwardRequest(conn, req, route) + if err == nil { + return + } + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + continue + } + + cc, err = route.Dial(host, + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + } + + if err != nil { + resp.StatusCode = http.StatusServiceUnavailable + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } + defer cc.Close() + + if req.Method == http.MethodConnect { + b := []byte("HTTP/1.1 200 Connection established\r\n" + + "Proxy-Agent: gost/" + Version + "\r\n\r\n") + if Debug { + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(b)) + } + conn.Write(b) + } else { + req.Header.Del("Proxy-Connection") + + if err = req.Write(cc); err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + } + + log.Logf("[http] %s <-> %s", conn.RemoteAddr(), host) + transport(conn, cc) + log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host) +} + +func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { + u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) + if Debug && (u != "" || p != "") { + log.Logf("[http] %s -> %s : Authorization '%s' '%s'", + conn.RemoteAddr(), conn.LocalAddr(), u, p) + } + if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { + return true + } + + // probing resistance is enabled, and knocking host is mismatch. + if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 && + (h.options.KnockingHost == "" || !strings.EqualFold(req.URL.Hostname(), h.options.KnockingHost)) { + resp.StatusCode = http.StatusServiceUnavailable // default status code + + switch ss[0] { + case "code": + resp.StatusCode, _ = strconv.Atoi(ss[1]) + case "web": + url := ss[1] + if !strings.HasPrefix(url, "http") { + url = "http://" + url + } + if r, err := http.Get(url); err == nil { + resp = r + } + case "host": + cc, err := net.Dial("tcp", ss[1]) + if err == nil { + defer cc.Close() + + req.Write(cc) + log.Logf("[http] %s <-> %s : forward to %s", + conn.RemoteAddr(), conn.LocalAddr(), ss[1]) + transport(conn, cc) + log.Logf("[http] %s >-< %s : forward to %s", + conn.RemoteAddr(), conn.LocalAddr(), ss[1]) + return + } + case "file": + f, _ := os.Open(ss[1]) + if f != nil { + resp.StatusCode = http.StatusOK + if finfo, _ := f.Stat(); finfo != nil { + resp.ContentLength = finfo.Size() + } + resp.Header.Set("Content-Type", "text/html") + resp.Body = f + } + } + } + + if resp.StatusCode == 0 { + log.Logf("[http] %s <- %s : proxy authentication required", + conn.RemoteAddr(), conn.LocalAddr()) + resp.StatusCode = http.StatusProxyAuthRequired + resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") + if strings.ToLower(req.Header.Get("Proxy-Connection")) == "keep-alive" { + // XXX libcurl will keep sending auth request in same conn + // which we don't supported yet. + resp.Header.Add("Connection", "close") + resp.Header.Add("Proxy-Connection", "close") + } + } else { + resp.Header = http.Header{} + resp.Header.Set("Server", "nginx/1.14.1") + resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + if resp.StatusCode == http.StatusOK { + resp.Header.Set("Connection", "keep-alive") + } + } + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return +} + +func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) error { + if route.IsEmpty() { + return nil + } + + host := req.Host + var userpass string + + if user := route.LastNode().User; user != nil { + u := user.Username() + p, _ := user.Password() + userpass = base64.StdEncoding.EncodeToString([]byte(u + ":" + p)) + } + + cc, err := route.Conn() + if err != nil { + return err + } + defer cc.Close() + + errc := make(chan error, 1) + go func() { + errc <- copyBuffer(conn, cc) + }() + + go func() { + for { + if userpass != "" { + req.Header.Set("Proxy-Authorization", "Basic "+userpass) + } + + cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) + if !req.URL.IsAbs() { + req.URL.Scheme = "http" // make sure that the URL is absolute + } + err := req.WriteProxy(cc) + if err != nil { + log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + errc <- err + return + } + cc.SetWriteDeadline(time.Time{}) + + req, err = http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + errc <- err + return + } + + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Logf("[http] %s -> %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + } + }() + + log.Logf("[http] %s <-> %s", conn.RemoteAddr(), host) + <-errc + log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host) + + return nil +} + +func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { + if proxyAuth == "" { + return + } + + if !strings.HasPrefix(proxyAuth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + + return cs[:s], cs[s+1:], true +} diff --git a/http2.go b/http2.go new file mode 100644 index 0000000..8c675bc --- /dev/null +++ b/http2.go @@ -0,0 +1,972 @@ +package gost + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-log/log" + "golang.org/x/net/http2" +) + +type http2Connector struct { + User *url.Userinfo +} + +// HTTP2Connector creates a Connector for HTTP2 proxy client. +// It accepts an optional auth info for HTTP Basic Authentication. +func HTTP2Connector(user *url.Userinfo) Connector { + return &http2Connector{User: user} +} + +func (c *http2Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *http2Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + ua := opts.UserAgent + if ua == "" { + ua = DefaultUserAgent + } + + cc, ok := conn.(*http2ClientConn) + if !ok { + return nil, errors.New("wrong connection type") + } + + pr, pw := io.Pipe() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Scheme: "https", Host: cc.addr}, + Header: make(http.Header), + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + Body: pr, + Host: address, + ContentLength: -1, + } + req.Header.Set("User-Agent", ua) + + user := opts.User + if user == nil { + user = c.User + } + + if user != nil { + u := user.Username() + p, _ := user.Password() + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) + } + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Log("[http2]", string(dump)) + } + resp, err := cc.client.Do(req) + if err != nil { + cc.Close() + return nil, err + } + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Log("[http2]", string(dump)) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, errors.New(resp.Status) + } + hc := &http2Conn{ + r: resp.Body, + w: pw, + closed: make(chan struct{}), + } + + hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", address) + hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr) + + return hc, nil +} + +type http2Transporter struct { + clients map[string]*http.Client + clientMutex sync.Mutex + tlsConfig *tls.Config +} + +// HTTP2Transporter creates a Transporter that is used by HTTP2 h2 proxy client. +func HTTP2Transporter(config *tls.Config) Transporter { + if config == nil { + config = &tls.Config{InsecureSkipVerify: true} + } + return &http2Transporter{ + clients: make(map[string]*http.Client), + tlsConfig: config, + } +} + +func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.clientMutex.Lock() + defer tr.clientMutex.Unlock() + + client, ok := tr.clients[addr] + if !ok { + // NOTE: There is no real connection to the HTTP2 server at this moment. + // So we try to connect to the server to check the server health. + conn, err := opts.Chain.Dial(addr) + if err != nil { + log.Log("http2 dial:", addr, err) + return nil, err + } + conn.Close() + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + transport := http2.Transport{ + TLSClientConfig: tr.tlsConfig, + DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { + conn, err := opts.Chain.Dial(adr) + if err != nil { + return nil, err + } + return wrapTLSClient(conn, cfg, timeout) + }, + } + client = &http.Client{ + Transport: &transport, + // Timeout: timeout, + } + tr.clients[addr] = client + } + + return &http2ClientConn{ + addr: addr, + client: client, + onClose: func() { + tr.clientMutex.Lock() + defer tr.clientMutex.Unlock() + delete(tr.clients, addr) + }, + }, nil +} + +func (tr *http2Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *http2Transporter) Multiplex() bool { + return true +} + +// TODO: clean closed clients +type h2Transporter struct { + clients map[string]*http.Client + clientMutex sync.Mutex + tlsConfig *tls.Config + path string +} + +// H2Transporter creates a Transporter that is used by HTTP2 h2 tunnel client. +func H2Transporter(config *tls.Config, path string) Transporter { + if config == nil { + config = &tls.Config{InsecureSkipVerify: true} + } + return &h2Transporter{ + clients: make(map[string]*http.Client), + tlsConfig: config, + path: path, + } +} + +// H2CTransporter creates a Transporter that is used by HTTP2 h2c tunnel client. +func H2CTransporter(path string) Transporter { + return &h2Transporter{ + clients: make(map[string]*http.Client), + path: path, + } +} + +func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.clientMutex.Lock() + client, ok := tr.clients[addr] + if !ok { + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + transport := http2.Transport{ + TLSClientConfig: tr.tlsConfig, + DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { + conn, err := opts.Chain.Dial(addr) + if err != nil { + return nil, err + } + if tr.tlsConfig == nil { + return conn, nil + } + return wrapTLSClient(conn, cfg, timeout) + }, + } + client = &http.Client{ + Transport: &transport, + // Timeout: timeout, + } + tr.clients[addr] = client + } + tr.clientMutex.Unlock() + + pr, pw := io.Pipe() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Scheme: "https", Host: opts.Host}, + Header: make(http.Header), + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + Body: pr, + Host: opts.Host, + ContentLength: -1, + } + if tr.path != "" { + req.Method = http.MethodGet + req.URL.Path = tr.path + } + //Process Header + for k, v := range opts.HeaderConfig { + if len(k) > 2 && k[0:2] == "--" { + req.Header.Del(k[2:]) + continue + } + req.Header.Set(k, v) + } + if Debug { + dump, _ := httputil.DumpRequest(req, false) + log.Log("[http2]", string(dump)) + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Log("[http2]", string(dump)) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, errors.New(resp.Status) + } + conn := &http2Conn{ + r: resp.Body, + w: pw, + closed: make(chan struct{}), + } + conn.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr) + conn.localAddr = &net.TCPAddr{IP: net.IPv4zero, Port: 0} + return conn, nil +} + +func (tr *h2Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + return conn, nil +} + +func (tr *h2Transporter) Multiplex() bool { + return true +} + +type http2Handler struct { + options *HandlerOptions +} + +// HTTP2Handler creates a server Handler for HTTP2 proxy server. +func HTTP2Handler(opts ...HandlerOption) Handler { + h := &http2Handler{} + h.Init(opts...) + + return h +} + +func (h *http2Handler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + +func (h *http2Handler) Handle(conn net.Conn) { + defer conn.Close() + + h2c, ok := conn.(*http2ServerConn) + if !ok { + log.Log("[http2] wrong connection type") + return + } + + h.roundTrip(h2c.w, h2c.r) +} + +func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { + host := r.Header.Get("Gost-Target") + if host == "" { + host = r.Host + } + + if _, port, _ := net.SplitHostPort(host); port == "" { + host = net.JoinHostPort(host, "80") + } + + laddr := h.options.Addr + u, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) + if u != "" { + u += "@" + } + log.Logf("[http2] %s%s -> %s -> %s", + u, r.RemoteAddr, h.options.Node.String(), host) + + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Logf("[http2] %s - %s\n%s", r.RemoteAddr, laddr, string(dump)) + } + + w.Header().Set("Proxy-Agent", "gost/"+Version) + + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http2] %s - %s : Unauthorized to tcp connect to %s", + r.RemoteAddr, laddr, host) + w.WriteHeader(http.StatusForbidden) + return + } + + if h.options.Bypass.Contains(host) { + log.Logf("[http2] %s - %s bypass %s", + r.RemoteAddr, laddr, host) + w.WriteHeader(http.StatusForbidden) + return + } + + resp := &http.Response{ + ProtoMajor: 2, + ProtoMinor: 0, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + + if !h.authenticate(w, r, resp) { + return + } + + // delete the proxy related headers. + r.Header.Del("Proxy-Authorization") + r.Header.Del("Proxy-Connection") + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var err error + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(host) + if err != nil { + log.Logf("[http2] %s -> %s : %s", + r.RemoteAddr, laddr, err) + continue + } + + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + r.RemoteAddr, h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", host) + log.Log("[route]", buf.String()) + + cc, err = route.Dial(host, + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, laddr, err) + } + + if err != nil { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + defer cc.Close() + + if r.Method == http.MethodConnect { + w.WriteHeader(http.StatusOK) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + + // compatible with HTTP1.x + if hj, ok := w.(http.Hijacker); ok && r.ProtoMajor == 1 { + // we take over the underly connection + conn, _, err := hj.Hijack() + if err != nil { + log.Logf("[http2] %s -> %s : %s", + r.RemoteAddr, laddr, err) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer conn.Close() + + log.Logf("[http2] %s <-> %s : downgrade to HTTP/1.1", r.RemoteAddr, host) + transport(conn, cc) + log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) + return + } + + log.Logf("[http2] %s <-> %s", r.RemoteAddr, host) + transport(&readWriter{r: r.Body, w: flushWriter{w}}, cc) + log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) + return + } + + log.Logf("[http2] %s <-> %s", r.RemoteAddr, host) + if err := h.forwardRequest(w, r, cc); err != nil { + log.Logf("[http2] %s - %s : %s", r.RemoteAddr, host, err) + } + log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) +} + +func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) { + laddr := h.options.Addr + u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) + if Debug && (u != "" || p != "") { + log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p) + } + if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { + return true + } + + // probing resistance is enabled, and knocking host is mismatch. + if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 && + (h.options.KnockingHost == "" || !strings.EqualFold(r.URL.Hostname(), h.options.KnockingHost)) { + resp.StatusCode = http.StatusServiceUnavailable // default status code + w.Header().Del("Proxy-Agent") + + switch ss[0] { + case "code": + resp.StatusCode, _ = strconv.Atoi(ss[1]) + case "web": + url := ss[1] + if !strings.HasPrefix(url, "http") { + url = "http://" + url + } + if r, err := http.Get(url); err == nil { + resp = r + } + case "host": + cc, err := net.Dial("tcp", ss[1]) + if err == nil { + defer cc.Close() + log.Logf("[http2] %s <-> %s : forward to %s", r.RemoteAddr, laddr, ss[1]) + if err := h.forwardRequest(w, r, cc); err != nil { + log.Logf("[http2] %s - %s : %s", r.RemoteAddr, laddr, err) + } + log.Logf("[http2] %s >-< %s : forward to %s", r.RemoteAddr, laddr, ss[1]) + return + } + case "file": + f, _ := os.Open(ss[1]) + if f != nil { + resp.StatusCode = http.StatusOK + if finfo, _ := f.Stat(); finfo != nil { + resp.ContentLength = finfo.Size() + } + resp.Body = f + } + } + } + + if resp.StatusCode == 0 { + log.Logf("[http2] %s <- %s : proxy authentication required", r.RemoteAddr, laddr) + resp.StatusCode = http.StatusProxyAuthRequired + resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") + } else { + resp.Header = http.Header{} + resp.Header.Set("Server", "nginx/1.14.1") + resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + if resp.ContentLength > 0 { + resp.Header.Set("Content-Type", "text/html") + } + if resp.StatusCode == http.StatusOK { + resp.Header.Set("Connection", "keep-alive") + } + } + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http2] %s <- %s\n%s", r.RemoteAddr, laddr, string(dump)) + } + + h.writeResponse(w, resp) + resp.Body.Close() + + return +} + +func (h *http2Handler) forwardRequest(w http.ResponseWriter, r *http.Request, rw io.ReadWriter) (err error) { + if err = r.Write(rw); err != nil { + return + } + + resp, err := http.ReadResponse(bufio.NewReader(rw), r) + if err != nil { + return + } + defer resp.Body.Close() + + return h.writeResponse(w, resp) +} + +func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) error { + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + w.WriteHeader(resp.StatusCode) + _, err := io.Copy(flushWriter{w}, resp.Body) + return err +} + +type http2Listener struct { + server *http.Server + connChan chan *http2ServerConn + addr net.Addr + errChan chan error +} + +// HTTP2Listener creates a Listener for HTTP2 proxy server. +func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { + l := &http2Listener{ + connChan: make(chan *http2ServerConn, 1024), + errChan: make(chan error, 1), + } + if config == nil { + config = DefaultTLSConfig + } + server := &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(l.handleFunc), + TLSConfig: config, + } + if err := http2.ConfigureServer(server, nil); err != nil { + return nil, err + } + l.server = server + + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + l.addr = ln.Addr() + + ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) + go func() { + err := server.Serve(ln) + if err != nil { + log.Log("[http2]", err) + } + }() + + return l, nil +} + +func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { + conn := &http2ServerConn{ + r: r, + w: w, + closed: make(chan struct{}), + } + select { + case l.connChan <- conn: + default: + log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr) + return + } + + <-conn.closed +} + +func (l *http2Listener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + if err == nil { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *http2Listener) Addr() net.Addr { + return l.addr +} + +func (l *http2Listener) Close() (err error) { + select { + case <-l.errChan: + default: + err = l.server.Close() + l.errChan <- err + close(l.errChan) + } + return nil +} + +type h2Listener struct { + net.Listener + server *http2.Server + tlsConfig *tls.Config + path string + connChan chan net.Conn + errChan chan error +} + +// H2Listener creates a Listener for HTTP2 h2 tunnel server. +func H2Listener(addr string, config *tls.Config, path string) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + if config == nil { + config = DefaultTLSConfig + } + + l := &h2Listener{ + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, + server: &http2.Server{ + // MaxConcurrentStreams: 1000, + PermitProhibitedCipherSuites: true, + IdleTimeout: 5 * time.Minute, + }, + tlsConfig: config, + path: path, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +// H2CListener creates a Listener for HTTP2 h2c tunnel server. +func H2CListener(addr string, path string) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + l := &h2Listener{ + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, + server: &http2.Server{ + // MaxConcurrentStreams: 1000, + }, + path: path, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +func (l *h2Listener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + log.Log("[http2] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.handleLoop(conn) + } +} + +func (l *h2Listener) handleLoop(conn net.Conn) { + if l.tlsConfig != nil { + conn = tls.Server(conn, l.tlsConfig) + } + + if tc, ok := conn.(*tls.Conn); ok { + // NOTE: HTTP2 server will check the TLS version, + // so we must ensure that the TLS connection is handshake completed. + if err := tc.Handshake(); err != nil { + log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + } + + opt := http2.ServeConnOpts{ + Handler: http.HandlerFunc(l.handleFunc), + } + l.server.ServeConn(conn, &opt) +} + +func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { + log.Logf("[http2] %s -> %s %s %s %s", + r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto) + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Log("[http2]", string(dump)) + } + w.Header().Set("Proxy-Agent", "gost/"+Version) + conn, err := l.upgrade(w, r) + if err != nil { + log.Logf("[http2] %s - %s %s %s %s: %s", + r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err) + return + } + select { + case l.connChan <- conn: + default: + conn.Close() + log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + + <-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed +} + +func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*http2Conn, error) { + if l.path == "" && r.Method != http.MethodConnect { + w.WriteHeader(http.StatusMethodNotAllowed) + return nil, errors.New("method not allowed") + } + + if l.path != "" && r.RequestURI != l.path { + w.WriteHeader(http.StatusBadRequest) + return nil, errors.New("bad request") + } + + w.WriteHeader(http.StatusOK) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() // write header to client + } + + remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if remoteAddr == nil { + remoteAddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + conn := &http2Conn{ + r: r.Body, + w: flushWriter{w}, + localAddr: l.Listener.Addr(), + remoteAddr: remoteAddr, + closed: make(chan struct{}), + } + return conn, nil +} + +func (l *h2Listener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +// HTTP2 connection, wrapped up just like a net.Conn +type http2Conn struct { + r io.Reader + w io.Writer + remoteAddr net.Addr + localAddr net.Addr + closed chan struct{} +} + +func (c *http2Conn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *http2Conn) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *http2Conn) Close() (err error) { + select { + case <-c.closed: + return + default: + close(c.closed) + } + if rc, ok := c.r.(io.Closer); ok { + err = rc.Close() + } + if w, ok := c.w.(io.Closer); ok { + err = w.Close() + } + return +} + +func (c *http2Conn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *http2Conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *http2Conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2Conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2Conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +// a dummy HTTP2 server conn used by HTTP2 handler +type http2ServerConn struct { + r *http.Request + w http.ResponseWriter + closed chan struct{} +} + +func (c *http2ServerConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *http2ServerConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *http2ServerConn) Close() error { + select { + case <-c.closed: + default: + close(c.closed) + } + return nil +} + +func (c *http2ServerConn) LocalAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.Host) + return addr +} + +func (c *http2ServerConn) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr) + return addr +} + +func (c *http2ServerConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2ServerConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2ServerConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +// a dummy HTTP2 client conn used by HTTP2 client connector +type http2ClientConn struct { + nopConn + addr string + client *http.Client + onClose func() +} + +func (c *http2ClientConn) Close() error { + if c.onClose != nil { + c.onClose() + } + return nil +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + defer func() { + if r := recover(); r != nil { + if s, ok := r.(string); ok { + err = errors.New(s) + log.Log("[http2]", err) + return + } + err = r.(error) + } + }() + + n, err = fw.w.Write(p) + if err != nil { + // log.Log("flush writer:", err) + return + } + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} diff --git a/http2_test.go b/http2_test.go new file mode 100644 index 0000000..762c6df --- /dev/null +++ b/http2_test.go @@ -0,0 +1,1151 @@ +package gost + +import ( + "bytes" + "crypto/rand" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func http2ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := HTTP2Listener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTP2Connector(clientInfo), + Transporter: HTTP2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTP2ProxyAuth(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := http2ProxyRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTP2Proxy(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := HTTP2Listener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(url.UserPassword("admin", "123456")), + Transporter: HTTP2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTP2ProxyParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := HTTP2Listener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(url.UserPassword("admin", "123456")), + Transporter: HTTP2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func httpOverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := H2Listener("", tlsConfig, "/h2") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverH2Roundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverH2(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := H2Listener("", nil, "/h2") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverH2Parallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := H2Listener("", nil, "/h2") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := H2Listener("", tlsConfig, "/h2") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverH2Roundtrip(httpSrv.URL, sendData, + nil, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := H2Listener("", tlsConfig, "/h2") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverH2Roundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := H2Listener("", tlsConfig, "/h2") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverH2Roundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := H2Listener("", tlsConfig, "/h2") + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverH2Roundtrip(httpSrv.URL, sendData, + nil, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverH2Roundtrip(targetURL string, data []byte, host string) error { + ln, err := H2Listener("", nil, "/h2") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: H2Transporter(nil, "/h2"), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverH2(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverH2Roundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func h2ForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := H2Listener("", nil, "") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: H2Transporter(nil, ""), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestH2ForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := h2ForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func httpOverH2CRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := H2CListener("", "/h2c") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverH2CRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverH2C(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := H2CListener("", "/h2c") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverH2CParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := H2CListener("", "/h2c") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverH2CRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := H2CListener("", "/h2c") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverH2CRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverH2CRoundtrip(targetURL string, data []byte) error { + ln, err := H2CListener("", "/h2c") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverH2CRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverH2CRoundtrip(targetURL string, data []byte) error { + ln, err := H2CListener("", "/h2c") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverH2CRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverH2CRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := H2CListener("", "/h2c") + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverH2CRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverH2CRoundtrip(targetURL string, data []byte, host string) error { + ln, err := H2CListener("", "/h2c") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: H2CTransporter("/h2c"), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverH2C(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverH2CRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func h2cForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := H2CListener("", "") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: H2CTransporter(""), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestH2CForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := h2cForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func TestHTTP2ProxyWithCodeProbeResist(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + ln, err := HTTP2Listener("", nil) + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(nil), + Transporter: HTTP2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("code:400"), + ), + } + go server.Run() + defer server.Close() + + err = proxyRoundtrip(client, server, httpSrv.URL, nil) + if err == nil { + t.Error("should failed with status code 400") + } else if err.Error() != "400 Bad Request" { + t.Error("should failed with status code 400, got", err.Error()) + } +} + +func TestHTTP2ProxyWithWebProbeResist(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + ln, err := HTTP2Listener("", nil) + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(nil), + Transporter: HTTP2Transporter(nil), + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("web:"+u.Host), + ), + } + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + t.Error(err) + } + defer conn.Close() + + conn, err = client.Connect(conn, "github.com:443") + if err != nil { + t.Error(err) + } + recv, _ := ioutil.ReadAll(conn) + if !bytes.Equal(recv, []byte("Hello World!")) { + t.Error("data not equal") + } +} + +func TestHTTP2ProxyWithHostProbeResist(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := HTTP2Listener("", nil) + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(nil), + Transporter: HTTP2Transporter(nil), + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("host:"+u.Host), + ), + } + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + t.Error(err) + } + defer conn.Close() + + cc, ok := conn.(*http2ClientConn) + if !ok { + t.Error("wrong connection type") + } + + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Scheme: "https", Host: cc.addr}, + Header: make(http.Header), + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + Body: ioutil.NopCloser(bytes.NewReader(sendData)), + Host: "github.com:443", + ContentLength: int64(len(sendData)), + } + + resp, err := cc.client.Do(req) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Error("got non-200 status:", resp.Status) + } + + recv, _ := ioutil.ReadAll(resp.Body) + if !bytes.Equal(sendData, recv) { + t.Error("data not equal") + } +} + +func TestHTTP2ProxyWithFileProbeResist(t *testing.T) { + ln, err := HTTP2Listener("", nil) + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(nil), + Transporter: HTTP2Transporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("file:.config/probe_resist.txt"), + ), + } + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + t.Error(err) + } + defer conn.Close() + + conn, err = client.Connect(conn, "github.com:443") + if err != nil { + t.Error(err) + } + recv, _ := ioutil.ReadAll(conn) + if !bytes.Equal(recv, []byte("Hello World!")) { + t.Error("data not equal") + } +} + +func TestHTTP2ProxyWithBypass(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + ln, err := HTTP2Listener("", nil) + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTP2Connector(nil), + Transporter: HTTP2Transporter(nil), + } + + host := u.Host + if h, _, _ := net.SplitHostPort(u.Host); h != "" { + host = h + } + server := &Server{ + Listener: ln, + Handler: HTTP2Handler( + BypassHandlerOption(NewBypassPatterns(false, host)), + ), + } + go server.Run() + defer server.Close() + + if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil { + t.Error("should failed") + } +} diff --git a/http_test.go b/http_test.go new file mode 100644 index 0000000..a4f1041 --- /dev/null +++ b/http_test.go @@ -0,0 +1,378 @@ +package gost + +import ( + "bytes" + "crypto/rand" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +var httpProxyTests = []struct { + cliUser *url.Userinfo + srvUsers []*url.Userinfo + errStr string +}{ + {nil, nil, ""}, + {nil, []*url.Userinfo{url.User("admin")}, "407 Proxy Authentication Required"}, + {nil, []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"}, + {url.User("admin"), []*url.Userinfo{url.User("test")}, "407 Proxy Authentication Required"}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, "407 Proxy Authentication Required"}, + {url.User("admin"), []*url.Userinfo{url.User("admin")}, ""}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""}, + {url.UserPassword("admin", "123456"), nil, ""}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""}, +} + +func httpProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPProxyAuth(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := httpProxyRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + }) + } +} + +func TestHTTPProxyWithInvalidRequest(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler(), + } + go server.Run() + defer server.Close() + + r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), bytes.NewReader(sendData)) + if err != nil { + t.Error(err) + } + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Error("got status:", resp.Status) + } +} + +func BenchmarkHTTPProxy(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPProxyParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func TestHTTPProxyWithCodeProbeResist(t *testing.T) { + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("code:400"), + ), + } + go server.Run() + defer server.Close() + + resp, err := http.Get("http://" + ln.Addr().String()) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 400 { + t.Error("should failed with status code 400, got", resp.Status) + } +} + +func TestHTTPProxyWithWebProbeResist(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("web:"+u.Host), + ), + } + go server.Run() + defer server.Close() + + r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) + if err != nil { + t.Error(err) + } + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Error("got status:", resp.Status) + } + + recv, _ := ioutil.ReadAll(resp.Body) + if !bytes.Equal(recv, []byte("Hello World!")) { + t.Error("data not equal") + } +} + +func TestHTTPProxyWithHostProbeResist(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("host:"+u.Host), + ), + } + go server.Run() + defer server.Close() + + r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), bytes.NewReader(sendData)) + if err != nil { + t.Error(err) + } + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Error("got status:", resp.Status) + } + + recv, _ := ioutil.ReadAll(resp.Body) + if !bytes.Equal(sendData, recv) { + t.Error("data not equal") + } +} + +func TestHTTPProxyWithFileProbeResist(t *testing.T) { + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ProbeResistHandlerOption("file:.config/probe_resist.txt"), + ), + } + go server.Run() + defer server.Close() + + r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) + if err != nil { + t.Error(err) + } + resp, err := http.DefaultClient.Do(r) + if err != nil { + t.Error(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Error("got status:", resp.Status) + } + + recv, _ := ioutil.ReadAll(resp.Body) + if !bytes.Equal(recv, []byte("Hello World!")) { + t.Error("data not equal, got:", string(recv)) + } +} + +func TestHTTPProxyWithBypass(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + u, err := url.Parse(httpSrv.URL) + if err != nil { + t.Error(err) + } + ln, err := TCPListener("") + if err != nil { + t.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(nil), + Transporter: TCPTransporter(), + } + + host := u.Host + if h, _, _ := net.SplitHostPort(u.Host); h != "" { + host = h + } + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + BypassHandlerOption(NewBypassPatterns(false, host)), + ), + } + go server.Run() + defer server.Close() + + if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil { + t.Error("should failed") + } +} diff --git a/kcp.go b/kcp.go new file mode 100644 index 0000000..cbb7e9e --- /dev/null +++ b/kcp.go @@ -0,0 +1,503 @@ +package gost + +import ( + "crypto/sha1" + "encoding/csv" + "errors" + "fmt" + "net" + "os" + "time" + + "golang.org/x/crypto/pbkdf2" + + "sync" + + "github.com/go-log/log" + "github.com/klauspost/compress/snappy" + "github.com/xtaci/kcp-go" + "github.com/xtaci/smux" + "github.com/xtaci/tcpraw" +) + +var ( + // KCPSalt is the default salt for KCP cipher. + KCPSalt = "kcp-go" +) + +// KCPConfig describes the config for KCP. +type KCPConfig struct { + Key string `json:"key"` + Crypt string `json:"crypt"` + Mode string `json:"mode"` + MTU int `json:"mtu"` + SndWnd int `json:"sndwnd"` + RcvWnd int `json:"rcvwnd"` + DataShard int `json:"datashard"` + ParityShard int `json:"parityshard"` + DSCP int `json:"dscp"` + NoComp bool `json:"nocomp"` + AckNodelay bool `json:"acknodelay"` + NoDelay int `json:"nodelay"` + Interval int `json:"interval"` + Resend int `json:"resend"` + NoCongestion int `json:"nc"` + SockBuf int `json:"sockbuf"` + KeepAlive int `json:"keepalive"` + SnmpLog string `json:"snmplog"` + SnmpPeriod int `json:"snmpperiod"` + Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature. + TCP bool `json:"tcp"` +} + +// Init initializes the KCP config. +func (c *KCPConfig) Init() { + switch c.Mode { + case "normal": + c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 40, 2, 1 + case "fast": + c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 30, 2, 1 + case "fast2": + c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 20, 2, 1 + case "fast3": + c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 10, 2, 1 + } +} + +var ( + // DefaultKCPConfig is the default KCP config. + DefaultKCPConfig = KCPConfig{ + Key: "it's a secrect", + Crypt: "aes", + Mode: "fast", + MTU: 1350, + SndWnd: 1024, + RcvWnd: 1024, + DataShard: 10, + ParityShard: 3, + DSCP: 0, + NoComp: false, + AckNodelay: false, + NoDelay: 0, + Interval: 50, + Resend: 0, + NoCongestion: 0, + SockBuf: 4194304, + KeepAlive: 10, + SnmpLog: "", + SnmpPeriod: 60, + Signal: false, + TCP: false, + } +) + +type kcpTransporter struct { + sessions map[string]*muxSession + sessionMutex sync.Mutex + config *KCPConfig +} + +// KCPTransporter creates a Transporter that is used by KCP proxy client. +func KCPTransporter(config *KCPConfig) Transporter { + if config == nil { + config = &KCPConfig{} + *config = DefaultKCPConfig + } + config.Init() + + go snmpLogger(config.SnmpLog, config.SnmpPeriod) + if config.Signal { + go kcpSigHandler() + } + + return &kcpTransporter{ + config: config, + sessions: make(map[string]*muxSession), + } +} + +func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if session != nil && session.session != nil && session.session.IsClosed() { + session.Close() + delete(tr.sessions, addr) // session is dead + ok = false + } + if !ok { + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + if tr.config.TCP { + pc, err := tcpraw.Dial("tcp", addr) + if err != nil { + return nil, err + } + conn = &fakeTCPConn{ + raddr: raddr, + PacketConn: pc, + } + } else { + conn, err = net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + } + session = &muxSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + config := tr.config + if opts.KCPConfig != nil { + config = opts.KCPConfig + } + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*muxSession, error) { + pc, ok := conn.(net.PacketConn) + if !ok { + return nil, errors.New("kcp: wrong connection type") + } + + kcpconn, err := kcp.NewConn(addr, + blockCrypt(config.Key, config.Crypt, KCPSalt), + config.DataShard, config.ParityShard, pc) + if err != nil { + return nil, err + } + + kcpconn.SetStreamMode(true) + kcpconn.SetWriteDelay(false) + kcpconn.SetNoDelay(config.NoDelay, config.Interval, config.Resend, config.NoCongestion) + kcpconn.SetWindowSize(config.SndWnd, config.RcvWnd) + kcpconn.SetMtu(config.MTU) + kcpconn.SetACKNoDelay(config.AckNodelay) + + if config.DSCP > 0 { + if err := kcpconn.SetDSCP(config.DSCP); err != nil { + log.Log("[kcp]", err) + } + } + if err := kcpconn.SetReadBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + if err := kcpconn.SetWriteBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + + // stream multiplex + smuxConfig := smux.DefaultConfig() + smuxConfig.MaxReceiveBuffer = config.SockBuf + smuxConfig.KeepAliveInterval = time.Duration(config.KeepAlive) * time.Second + var cc net.Conn = kcpconn + if !config.NoComp { + cc = newCompStreamConn(kcpconn) + } + session, err := smux.Client(cc, smuxConfig) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} + +func (tr *kcpTransporter) Multiplex() bool { + return true +} + +type kcpListener struct { + config *KCPConfig + ln *kcp.Listener + connChan chan net.Conn + errChan chan error +} + +// KCPListener creates a Listener for KCP proxy server. +func KCPListener(addr string, config *KCPConfig) (Listener, error) { + if config == nil { + config = &KCPConfig{} + *config = DefaultKCPConfig + } + config.Init() + + var err error + var ln *kcp.Listener + if config.TCP { + var conn net.PacketConn + conn, err = tcpraw.Listen("tcp", addr) + if err != nil { + return nil, err + } + ln, err = kcp.ServeConn( + blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard, conn) + if err != nil { + return nil, err + } + } else { + ln, err = kcp.ListenWithOptions(addr, + blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard) + } + if err != nil { + return nil, err + } + if config.DSCP > 0 { + if err = ln.SetDSCP(config.DSCP); err != nil { + log.Log("[kcp]", err) + } + } + if err = ln.SetReadBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + if err = ln.SetWriteBuffer(config.SockBuf); err != nil { + log.Log("[kcp]", err) + } + + go snmpLogger(config.SnmpLog, config.SnmpPeriod) + if config.Signal { + go kcpSigHandler() + } + + l := &kcpListener{ + config: config, + ln: ln, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +func (l *kcpListener) listenLoop() { + for { + conn, err := l.ln.AcceptKCP() + if err != nil { + log.Log("[kcp] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + conn.SetStreamMode(true) + conn.SetWriteDelay(false) + conn.SetNoDelay(l.config.NoDelay, l.config.Interval, l.config.Resend, l.config.NoCongestion) + conn.SetMtu(l.config.MTU) + conn.SetWindowSize(l.config.SndWnd, l.config.RcvWnd) + conn.SetACKNoDelay(l.config.AckNodelay) + go l.mux(conn) + } +} + +func (l *kcpListener) mux(conn net.Conn) { + smuxConfig := smux.DefaultConfig() + smuxConfig.MaxReceiveBuffer = l.config.SockBuf + smuxConfig.KeepAliveInterval = time.Duration(l.config.KeepAlive) * time.Second + + log.Logf("[kcp] %s - %s", conn.RemoteAddr(), l.Addr()) + + if !l.config.NoComp { + conn = newCompStreamConn(conn) + } + + mux, err := smux.Server(conn, smuxConfig) + if err != nil { + log.Log("[kcp]", err) + return + } + defer mux.Close() + + log.Logf("[kcp] %s <-> %s", conn.RemoteAddr(), l.Addr()) + defer log.Logf("[kcp] %s >-< %s", conn.RemoteAddr(), l.Addr()) + + for { + stream, err := mux.AcceptStream() + if err != nil { + log.Log("[kcp] accept stream:", err) + return + } + + cc := &muxStreamConn{Conn: conn, stream: stream} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + } +} + +func (l *kcpListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} +func (l *kcpListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *kcpListener) Close() error { + return l.ln.Close() +} + +func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) { + pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New) + + switch crypt { + case "sm4": + block, _ = kcp.NewSM4BlockCrypt(pass[:16]) + case "tea": + block, _ = kcp.NewTEABlockCrypt(pass[:16]) + case "xor": + block, _ = kcp.NewSimpleXORBlockCrypt(pass) + case "none": + block, _ = kcp.NewNoneBlockCrypt(pass) + case "aes-128": + block, _ = kcp.NewAESBlockCrypt(pass[:16]) + case "aes-192": + block, _ = kcp.NewAESBlockCrypt(pass[:24]) + case "blowfish": + block, _ = kcp.NewBlowfishBlockCrypt(pass) + case "twofish": + block, _ = kcp.NewTwofishBlockCrypt(pass) + case "cast5": + block, _ = kcp.NewCast5BlockCrypt(pass[:16]) + case "3des": + block, _ = kcp.NewTripleDESBlockCrypt(pass[:24]) + case "xtea": + block, _ = kcp.NewXTEABlockCrypt(pass[:16]) + case "salsa20": + block, _ = kcp.NewSalsa20BlockCrypt(pass) + case "aes": + fallthrough + default: // aes + block, _ = kcp.NewAESBlockCrypt(pass) + } + return +} + +func snmpLogger(format string, interval int) { + if format == "" || interval == 0 { + return + } + ticker := time.NewTicker(time.Duration(interval) * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + f, err := os.OpenFile(time.Now().Format(format), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + log.Log("[kcp]", err) + return + } + w := csv.NewWriter(f) + // write header in empty file + if stat, err := f.Stat(); err == nil && stat.Size() == 0 { + if err := w.Write(append([]string{"Unix"}, kcp.DefaultSnmp.Header()...)); err != nil { + log.Log("[kcp]", err) + } + } + if err := w.Write(append([]string{fmt.Sprint(time.Now().Unix())}, kcp.DefaultSnmp.ToSlice()...)); err != nil { + log.Log("[kcp]", err) + } + kcp.DefaultSnmp.Reset() + w.Flush() + f.Close() + } + } +} + +type compStreamConn struct { + conn net.Conn + w *snappy.Writer + r *snappy.Reader +} + +func newCompStreamConn(conn net.Conn) *compStreamConn { + c := new(compStreamConn) + c.conn = conn + c.w = snappy.NewBufferedWriter(conn) + c.r = snappy.NewReader(conn) + return c +} + +func (c *compStreamConn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *compStreamConn) Write(b []byte) (n int, err error) { + n, err = c.w.Write(b) + err = c.w.Flush() + return n, err +} + +func (c *compStreamConn) Close() error { + return c.conn.Close() +} + +func (c *compStreamConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *compStreamConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *compStreamConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *compStreamConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *compStreamConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/kcp_test.go b/kcp_test.go new file mode 100644 index 0000000..d9e6ba0 --- /dev/null +++ b/kcp_test.go @@ -0,0 +1,408 @@ +package gost + +import ( + "crypto/rand" + "fmt" + "net/http/httptest" + "net/url" + "testing" +) + +func httpOverKCPRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverKCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverKCPRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverKCP(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := KCPListener("", nil) + if err != nil { + b.Error(err) + } + b.Log(ln.Addr()) + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverKCPParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := KCPListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverKCPRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverKCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverKCPRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverKCPRoundtrip(targetURL string, data []byte) error { + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverKCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverKCPRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverKCPRoundtrip(targetURL string, data []byte) error { + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverKCP(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverKCPRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverKCPRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverKCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverKCPRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverKCPRoundtrip(targetURL string, data []byte, host string) error { + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverKCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverKCPRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func kcpForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := KCPListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: KCPTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestKCPForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := kcpForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..d4ad519 --- /dev/null +++ b/log.go @@ -0,0 +1,36 @@ +package gost + +import ( + "fmt" + "log" +) + +func init() { + log.SetFlags(log.LstdFlags | log.Lshortfile) +} + +// LogLogger uses the standard log package as the logger +type LogLogger struct { +} + +// Log uses the standard log library log.Output +func (l *LogLogger) Log(v ...interface{}) { + log.Output(3, fmt.Sprintln(v...)) +} + +// Logf uses the standard log library log.Output +func (l *LogLogger) Logf(format string, v ...interface{}) { + log.Output(3, fmt.Sprintf(format, v...)) +} + +// NopLogger is a dummy logger that discards the log outputs +type NopLogger struct { +} + +// Log does nothing +func (l *NopLogger) Log(v ...interface{}) { +} + +// Logf does nothing +func (l *NopLogger) Logf(format string, v ...interface{}) { +} diff --git a/mux.go b/mux.go new file mode 100644 index 0000000..a378452 --- /dev/null +++ b/mux.go @@ -0,0 +1,63 @@ +package gost + +import ( + "net" + + smux "github.com/xtaci/smux" +) + +type muxStreamConn struct { + net.Conn + stream *smux.Stream +} + +func (c *muxStreamConn) Read(b []byte) (n int, err error) { + return c.stream.Read(b) +} + +func (c *muxStreamConn) Write(b []byte) (n int, err error) { + return c.stream.Write(b) +} + +func (c *muxStreamConn) Close() error { + return c.stream.Close() +} + +type muxSession struct { + conn net.Conn + session *smux.Session +} + +func (session *muxSession) GetConn() (net.Conn, error) { + stream, err := session.session.OpenStream() + if err != nil { + return nil, err + } + return &muxStreamConn{Conn: session.conn, stream: stream}, nil +} + +func (session *muxSession) Accept() (net.Conn, error) { + stream, err := session.session.AcceptStream() + if err != nil { + return nil, err + } + return &muxStreamConn{Conn: session.conn, stream: stream}, nil +} + +func (session *muxSession) Close() error { + if session.session == nil { + return nil + } + return session.session.Close() +} + +func (session *muxSession) IsClosed() bool { + if session.session == nil { + return true + } + return session.session.IsClosed() +} + +func (session *muxSession) NumStreams() int { + return session.session.NumStreams() +} diff --git a/node.go b/node.go new file mode 100644 index 0000000..f64afc4 --- /dev/null +++ b/node.go @@ -0,0 +1,284 @@ +package gost + +import ( + "errors" + "fmt" + "net/url" + "strconv" + "strings" + "sync" + "time" +) + +var ( + // ErrInvalidNode is an error that implies the node is invalid. + ErrInvalidNode = errors.New("invalid node") +) + +// Node is a proxy node, mainly used to construct a proxy chain. +type Node struct { + ID int + Addr string + Host string + Protocol string + Transport string + Remote string // remote address, used by tcp/udp port forwarding + url *url.URL // raw url + User *url.Userinfo + Values url.Values + DialOptions []DialOption + HandshakeOptions []HandshakeOption + ConnectOptions []ConnectOption + Client *Client + marker *failMarker + Bypass *Bypass +} + +// ParseNode parses the node info. +// The proxy node string pattern is [scheme://][user:pass@host]:port. +// Scheme can be divided into two parts by character '+', such as: http+tls. +func ParseNode(s string) (node Node, err error) { + s = strings.TrimSpace(s) + if s == "" { + return Node{}, ErrInvalidNode + } + + if !strings.Contains(s, "://") { + s = "auto://" + s + } + u, err := url.Parse(s) + if err != nil { + return + } + + node = Node{ + Addr: u.Host, + Host: u.Host, + Remote: strings.Trim(u.EscapedPath(), "/"), + Values: u.Query(), + User: u.User, + marker: &failMarker{}, + url: u, + } + + u.RawQuery = "" + u.User = nil + + schemes := strings.Split(u.Scheme, "+") + if len(schemes) == 1 { + node.Protocol = schemes[0] + node.Transport = schemes[0] + } + if len(schemes) == 2 { + node.Protocol = schemes[0] + node.Transport = schemes[1] + } + + switch node.Transport { + case "https": + node.Transport = "tls" + case "tls", "mtls": + case "http2", "h2", "h2c": + case "ws", "mws", "wss", "mwss": + case "kcp", "ssh", "quic": + case "ssu": + node.Transport = "udp" + case "ohttp", "otls", "obfs4": // obfs + case "tcp", "udp": + case "rtcp", "rudp": // rtcp and rudp are for remote port forwarding + case "tun", "tap": // tun/tap device + case "ftcp": // fake TCP + case "dns": + case "redu", "redirectu": // UDP tproxy + default: + node.Transport = "tcp" + } + + switch node.Protocol { + case "http", "http2": + case "https": + node.Protocol = "http" + case "socks4", "socks4a": + case "socks", "socks5": + node.Protocol = "socks5" + case "ss", "ssu": + case "ss2": // as of 2.10.1, ss2 is same as ss + node.Protocol = "ss" + case "sni": + case "tcp", "udp", "rtcp", "rudp": // port forwarding + case "direct", "remote", "forward": // forwarding + case "red", "redirect", "redu", "redirectu": // TCP,UDP transparent proxy + case "tun", "tap": // tun/tap device + case "ftcp": // fake TCP + case "dns", "dot", "doh": + case "relay": + default: + node.Protocol = "" + } + + return +} + +// MarkDead marks the node fail status. +func (node *Node) MarkDead() { + if node.marker == nil { + return + } + node.marker.Mark() +} + +// ResetDead resets the node fail status. +func (node *Node) ResetDead() { + if node.marker == nil { + return + } + node.marker.Reset() +} + +// Clone clones the node, it will prevent data race. +func (node *Node) Clone() Node { + nd := *node + if node.marker != nil { + nd.marker = node.marker.Clone() + } + return nd +} + +// Get returns node parameter specified by key. +func (node *Node) Get(key string) string { + return node.Values.Get(key) +} + +// GetBool converts node parameter value to bool. +func (node *Node) GetBool(key string) bool { + b, _ := strconv.ParseBool(node.Values.Get(key)) + return b +} + +// GetInt converts node parameter value to int. +func (node *Node) GetInt(key string) int { + n, _ := strconv.Atoi(node.Get(key)) + return n +} + +// GetDuration converts node parameter value to time.Duration. +func (node *Node) GetDuration(key string) time.Duration { + d, err := time.ParseDuration(node.Get(key)) + if err != nil { + d = time.Duration(node.GetInt(key)) * time.Second + } + return d +} + +func (node Node) String() string { + var scheme string + if node.url != nil { + scheme = node.url.Scheme + } + if scheme == "" { + scheme = fmt.Sprintf("%s+%s", node.Protocol, node.Transport) + } + return fmt.Sprintf("%s://%s", + scheme, node.Addr) +} + +// NodeGroup is a group of nodes. +type NodeGroup struct { + ID int + nodes []Node + selectorOptions []SelectOption + selector NodeSelector + mux sync.RWMutex +} + +// NewNodeGroup creates a node group +func NewNodeGroup(nodes ...Node) *NodeGroup { + return &NodeGroup{ + nodes: nodes, + } +} + +// AddNode appends node or node list into group node. +func (group *NodeGroup) AddNode(node ...Node) { + if group == nil { + return + } + group.mux.Lock() + defer group.mux.Unlock() + + group.nodes = append(group.nodes, node...) +} + +// SetNodes replaces the group nodes to the specified nodes, +// and returns the previous nodes. +func (group *NodeGroup) SetNodes(nodes ...Node) []Node { + if group == nil { + return nil + } + + group.mux.Lock() + defer group.mux.Unlock() + + old := group.nodes + group.nodes = nodes + return old +} + +// SetSelector sets node selector with options for the group. +func (group *NodeGroup) SetSelector(selector NodeSelector, opts ...SelectOption) { + if group == nil { + return + } + group.mux.Lock() + defer group.mux.Unlock() + + group.selector = selector + group.selectorOptions = opts +} + +// Nodes returns the node list in the group +func (group *NodeGroup) Nodes() []Node { + if group == nil { + return nil + } + + group.mux.RLock() + defer group.mux.RUnlock() + + return group.nodes +} + +// GetNode returns the node specified by index in the group. +func (group *NodeGroup) GetNode(i int) Node { + group.mux.RLock() + defer group.mux.RUnlock() + + if i < 0 || group == nil || len(group.nodes) <= i { + return Node{} + } + return group.nodes[i] +} + +// Next selects a node from group. +// It also selects IP if the IP list exists. +func (group *NodeGroup) Next() (node Node, err error) { + if group == nil { + return + } + + group.mux.RLock() + defer group.mux.RUnlock() + + selector := group.selector + if selector == nil { + selector = &defaultSelector{} + } + + // select node from node group + node, err = selector.Select(group.nodes, group.selectorOptions...) + if err != nil { + return + } + + return +} diff --git a/node_test.go b/node_test.go new file mode 100644 index 0000000..31718f4 --- /dev/null +++ b/node_test.go @@ -0,0 +1,68 @@ +package gost + +import "testing" +import "net/url" + +var nodeTests = []struct { + in string + out Node + hasError bool +}{ + {"", Node{}, true}, + {"://", Node{}, true}, + {"localhost", Node{Addr: "localhost", Transport: "tcp"}, false}, + {":", Node{Addr: ":", Transport: "tcp"}, false}, + {":8080", Node{Addr: ":8080", Transport: "tcp"}, false}, + {"http://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp"}, false}, + {"http://localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp"}, false}, + {"http://admin:123456@:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("admin", "123456")}, false}, + {"http://admin@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("admin")}, false}, + {"http://:123456@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "123456")}, false}, + {"http://@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.User("")}, false}, + {"http://:@localhost:8080", Node{Addr: "localhost:8080", Protocol: "http", Transport: "tcp", User: url.UserPassword("", "")}, false}, + {"https://:8080", Node{Addr: ":8080", Protocol: "http", Transport: "tls"}, false}, + {"socks+tls://:8080", Node{Addr: ":8080", Protocol: "socks5", Transport: "tls"}, false}, + {"tls://:8080", Node{Addr: ":8080", Transport: "tls"}, false}, + {"tcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "tcp", Transport: "tcp"}, false}, + {"udp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "udp", Transport: "udp"}, false}, + {"rtcp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rtcp", Transport: "rtcp"}, false}, + {"rudp://:8080/:8081", Node{Addr: ":8080", Remote: ":8081", Protocol: "rudp", Transport: "rudp"}, false}, + {"redirect://:8080", Node{Addr: ":8080", Protocol: "redirect", Transport: "tcp"}, false}, +} + +func TestParseNode(t *testing.T) { + for _, test := range nodeTests { + actual, err := ParseNode(test.in) + if err != nil { + if test.hasError { + // t.Logf("ParseNode(%q) got expected error: %v", test.in, err) + continue + } + t.Errorf("ParseNode(%q) got error: %v", test.in, err) + } else { + if test.hasError { + t.Errorf("ParseNode(%q) got %v, but should return error", test.in, actual) + continue + } + if actual.Addr != test.out.Addr || actual.Protocol != test.out.Protocol || + actual.Transport != test.out.Transport || actual.Remote != test.out.Remote { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + } + if actual.User == nil { + if test.out.User != nil { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + } + continue + } + if actual.User != nil { + if test.out.User == nil { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + continue + } + if *actual.User != *test.out.User { + t.Errorf("ParseNode(%q) got %v, want %v", test.in, actual, test.out) + } + } + } + } +} diff --git a/obfs.go b/obfs.go new file mode 100644 index 0000000..ffc161b --- /dev/null +++ b/obfs.go @@ -0,0 +1,818 @@ +// obfs4 connection wrappers + +package gost + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "time" + + "github.com/go-log/log" + + pt "git.torproject.org/pluggable-transports/goptlib.git" + dissector "github.com/go-gost/tls-dissector" + "gitlab.com/yawning/obfs4.git/transports/base" + "gitlab.com/yawning/obfs4.git/transports/obfs4" +) + +const ( + maxTLSDataLen = 16384 +) + +type obfsHTTPTransporter struct { + tcpTransporter +} + +// ObfsHTTPTransporter creates a Transporter that is used by HTTP obfuscating tunnel client. +func ObfsHTTPTransporter() Transporter { + return &obfsHTTPTransporter{} +} + +func (tr *obfsHTTPTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + return &obfsHTTPConn{Conn: conn, host: opts.Host}, nil +} + +type obfsHTTPListener struct { + net.Listener +} + +// ObfsHTTPListener creates a Listener for HTTP obfuscating tunnel server. +func ObfsHTTPListener(addr string) (Listener, error) { + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, err + } + return &obfsHTTPListener{Listener: tcpKeepAliveListener{ln}}, nil +} + +func (l *obfsHTTPListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return &obfsHTTPConn{Conn: conn, isServer: true}, nil +} + +type obfsHTTPConn struct { + net.Conn + host string + rbuf bytes.Buffer + wbuf bytes.Buffer + isServer bool + headerDrained bool + handshaked bool + handshakeMutex sync.Mutex +} + +func (c *obfsHTTPConn) Handshake() (err error) { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.handshaked { + return nil + } + + if c.isServer { + err = c.serverHandshake() + } else { + err = c.clientHandshake() + } + if err != nil { + return + } + + c.handshaked = true + return nil +} + +func (c *obfsHTTPConn) serverHandshake() (err error) { + br := bufio.NewReader(c.Conn) + r, err := http.ReadRequest(br) + if err != nil { + return + } + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump)) + } + + if r.ContentLength > 0 { + _, err = io.Copy(&c.rbuf, r.Body) + } else { + var b []byte + b, err = br.Peek(br.Buffered()) + if len(b) > 0 { + _, err = c.rbuf.Write(b) + } + } + if err != nil { + log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err) + return + } + + b := bytes.Buffer{} + + if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" { + b.WriteString("HTTP/1.1 503 Service Unavailable\r\n") + b.WriteString("Content-Length: 0\r\n") + b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") + b.WriteString("\r\n") + + if Debug { + log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) + } + + b.WriteTo(c.Conn) + return errors.New("bad request") + } + + b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") + b.WriteString("Server: nginx/1.10.0\r\n") + b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") + b.WriteString("Connection: Upgrade\r\n") + b.WriteString("Upgrade: websocket\r\n") + b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))) + b.WriteString("\r\n") + + if Debug { + log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) + } + + if c.rbuf.Len() > 0 { + c.wbuf = b // cache the response header if there are extra data in the request body. + return + } + + _, err = b.WriteTo(c.Conn) + return +} + +func (c *obfsHTTPConn) clientHandshake() (err error) { + r := &http.Request{ + Method: http.MethodGet, + ProtoMajor: 1, + ProtoMinor: 1, + URL: &url.URL{Scheme: "http", Host: c.host}, + Header: make(http.Header), + } + r.Header.Set("User-Agent", DefaultUserAgent) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + key, _ := generateChallengeKey() + r.Header.Set("Sec-WebSocket-Key", key) + + // cache the request header + if err = r.Write(&c.wbuf); err != nil { + return + } + + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Logf("[ohttp] %s -> %s\n%s", c.LocalAddr(), c.RemoteAddr(), string(dump)) + } + + return nil +} + +func (c *obfsHTTPConn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + + if !c.isServer { + if err = c.drainHeader(); err != nil { + return + } + } + + if c.rbuf.Len() > 0 { + return c.rbuf.Read(b) + } + return c.Conn.Read(b) +} + +func (c *obfsHTTPConn) drainHeader() (err error) { + if c.headerDrained { + return + } + c.headerDrained = true + + br := bufio.NewReader(c.Conn) + // drain and discard the response header + var line string + var buf bytes.Buffer + for { + line, err = br.ReadString('\n') + if err != nil { + return + } + buf.WriteString(line) + if line == "\r\n" { + break + } + } + + if Debug { + log.Logf("[ohttp] %s <- %s\n%s", c.LocalAddr(), c.RemoteAddr(), buf.String()) + } + // cache the extra data for next read. + var b []byte + b, err = br.Peek(br.Buffered()) + if len(b) > 0 { + _, err = c.rbuf.Write(b) + } + return +} + +func (c *obfsHTTPConn) Write(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + n = len(b) // exclude the header length + return + } + return c.Conn.Write(b) +} + +type obfsTLSTransporter struct { + tcpTransporter +} + +// ObfsTLSTransporter creates a Transporter that is used by TLS obfuscating. +func ObfsTLSTransporter() Transporter { + return &obfsTLSTransporter{} +} + +func (tr *obfsTLSTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + return ClientObfsTLSConn(conn, opts.Host), nil +} + +type obfsTLSListener struct { + net.Listener +} + +// ObfsTLSListener creates a Listener for TLS obfuscating server. +func ObfsTLSListener(addr string) (Listener, error) { + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, err + } + return &obfsTLSListener{Listener: tcpKeepAliveListener{ln}}, nil +} + +func (l *obfsTLSListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return ServerObfsTLSConn(conn, ""), nil +} + +var ( + cipherSuites = []uint16{ + 0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f, + 0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a, + 0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d, + 0x003c, 0x0035, 0x002f, 0x00ff, + } + + compressionMethods = []uint8{0x00} + + algorithms = []uint16{ + 0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402, + 0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203, + } + + tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17} + tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03} + + ErrBadType = errors.New("bad type") + ErrBadMajorVersion = errors.New("bad major version") + ErrBadMinorVersion = errors.New("bad minor version") + ErrMaxDataLen = errors.New("bad tls data len") +) + +const ( + tlsRecordStateType = iota + tlsRecordStateVersion0 + tlsRecordStateVersion1 + tlsRecordStateLength0 + tlsRecordStateLength1 + tlsRecordStateData +) + +type obfsTLSParser struct { + step uint8 + state uint8 + length uint16 +} + +type obfsTLSConn struct { + net.Conn + rbuf bytes.Buffer + wbuf bytes.Buffer + host string + isServer bool + handshaked chan struct{} + parser *obfsTLSParser + handshakeMutex sync.Mutex +} + +func (r *obfsTLSParser) Parse(b []byte) (int, error) { + i := 0 + last := 0 + length := len(b) + + for i < length { + ch := b[i] + switch r.state { + case tlsRecordStateType: + if tlsRecordTypes[r.step] != ch { + return 0, ErrBadType + } + r.state = tlsRecordStateVersion0 + i++ + case tlsRecordStateVersion0: + if ch != 0x03 { + return 0, ErrBadMajorVersion + } + r.state = tlsRecordStateVersion1 + i++ + case tlsRecordStateVersion1: + if ch != tlsVersionMinors[r.step] { + return 0, ErrBadMinorVersion + } + r.state = tlsRecordStateLength0 + i++ + case tlsRecordStateLength0: + r.length = uint16(ch) << 8 + r.state = tlsRecordStateLength1 + i++ + case tlsRecordStateLength1: + r.length |= uint16(ch) + if r.step == 0 { + r.length = 91 + } else if r.step == 1 { + r.length = 1 + } else if r.length > maxTLSDataLen { + return 0, ErrMaxDataLen + } + if r.length > 0 { + r.state = tlsRecordStateData + } else { + r.state = tlsRecordStateType + r.step++ + } + i++ + case tlsRecordStateData: + left := uint16(length - i) + if left > r.length { + left = r.length + } + if r.step >= 2 { + skip := i - last + copy(b[last:], b[i:length]) + length -= int(skip) + last += int(left) + i = last + } else { + i += int(left) + } + r.length -= left + if r.length == 0 { + if r.step < 3 { + r.step++ + } + r.state = tlsRecordStateType + } + } + } + + if last == 0 { + return 0, nil + } else if last < length { + length -= last + } + + return length, nil +} + +// ClientObfsTLSConn creates a connection for obfs-tls client. +func ClientObfsTLSConn(conn net.Conn, host string) net.Conn { + return &obfsTLSConn{ + Conn: conn, + host: host, + handshaked: make(chan struct{}), + parser: &obfsTLSParser{}, + } +} + +// ServerObfsTLSConn creates a connection for obfs-tls server. +func ServerObfsTLSConn(conn net.Conn, host string) net.Conn { + return &obfsTLSConn{ + Conn: conn, + host: host, + isServer: true, + handshaked: make(chan struct{}), + } +} + +func (c *obfsTLSConn) Handshaked() bool { + select { + case <-c.handshaked: + return true + default: + return false + } +} + +func (c *obfsTLSConn) Handshake(payload []byte) (err error) { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + if c.Handshaked() { + return + } + + if c.isServer { + err = c.serverHandshake() + } else { + err = c.clientHandshake(payload) + } + if err != nil { + return + } + + close(c.handshaked) + return nil +} + +func (c *obfsTLSConn) clientHandshake(payload []byte) error { + clientMsg := &dissector.ClientHelloMsg{ + Version: tls.VersionTLS12, + SessionID: make([]byte, 32), + CipherSuites: cipherSuites, + CompressionMethods: compressionMethods, + Extensions: []dissector.Extension{ + &dissector.SessionTicketExtension{ + Data: payload, + }, + &dissector.ServerNameExtension{ + Name: c.host, + }, + &dissector.ECPointFormatsExtension{ + Formats: []uint8{0x01, 0x00, 0x02}, + }, + &dissector.SupportedGroupsExtension{ + Groups: []uint16{0x001d, 0x0017, 0x0019, 0x0018}, + }, + &dissector.SignatureAlgorithmsExtension{ + Algorithms: algorithms, + }, + &dissector.EncryptThenMacExtension{}, + &dissector.ExtendedMasterSecretExtension{}, + }, + } + clientMsg.Random.Time = uint32(time.Now().Unix()) + rand.Read(clientMsg.Random.Opaque[:]) + rand.Read(clientMsg.SessionID) + b, err := clientMsg.Encode() + if err != nil { + return err + } + + record := &dissector.Record{ + Type: dissector.Handshake, + Version: tls.VersionTLS10, + Opaque: b, + } + if _, err := record.WriteTo(c.Conn); err != nil { + return err + } + return err +} + +func (c *obfsTLSConn) serverHandshake() error { + record := &dissector.Record{} + if _, err := record.ReadFrom(c.Conn); err != nil { + log.Log(err) + return err + } + if record.Type != dissector.Handshake { + return dissector.ErrBadType + } + + clientMsg := &dissector.ClientHelloMsg{} + if err := clientMsg.Decode(record.Opaque); err != nil { + log.Log(err) + return err + } + + for _, ext := range clientMsg.Extensions { + if ext.Type() == dissector.ExtSessionTicket { + b, err := ext.Encode() + if err != nil { + log.Log(err) + return err + } + c.rbuf.Write(b) + break + } + } + + serverMsg := &dissector.ServerHelloMsg{ + Version: tls.VersionTLS12, + SessionID: clientMsg.SessionID, + CipherSuite: 0xcca8, + CompressionMethod: 0x00, + Extensions: []dissector.Extension{ + &dissector.RenegotiationInfoExtension{}, + &dissector.ExtendedMasterSecretExtension{}, + &dissector.ECPointFormatsExtension{ + Formats: []uint8{0x00}, + }, + }, + } + + serverMsg.Random.Time = uint32(time.Now().Unix()) + rand.Read(serverMsg.Random.Opaque[:]) + b, err := serverMsg.Encode() + if err != nil { + return err + } + + record = &dissector.Record{ + Type: dissector.Handshake, + Version: tls.VersionTLS10, + Opaque: b, + } + + if _, err := record.WriteTo(&c.wbuf); err != nil { + return err + } + + record = &dissector.Record{ + Type: dissector.ChangeCipherSpec, + Version: tls.VersionTLS12, + Opaque: []byte{0x01}, + } + if _, err := record.WriteTo(&c.wbuf); err != nil { + return err + } + return nil +} + +func (c *obfsTLSConn) Read(b []byte) (n int, err error) { + if c.isServer { // NOTE: only Write performs the handshake operation on client side. + if err = c.Handshake(nil); err != nil { + return + } + } + + select { + case <-c.handshaked: + } + + if c.isServer { + if c.rbuf.Len() > 0 { + return c.rbuf.Read(b) + } + record := &dissector.Record{} + if _, err = record.ReadFrom(c.Conn); err != nil { + return + } + n = copy(b, record.Opaque) + _, err = c.rbuf.Write(record.Opaque[n:]) + } else { + n, err = c.Conn.Read(b) + if err != nil { + return + } + if n > 0 { + n, err = c.parser.Parse(b[:n]) + } + } + return +} + +func (c *obfsTLSConn) Write(b []byte) (n int, err error) { + n = len(b) + if !c.Handshaked() { + if err = c.Handshake(b); err != nil { + return + } + if !c.isServer { // the data b has been sended during handshake phase. + return + } + } + + for len(b) > 0 { + data := b + if len(b) > maxTLSDataLen { + data = b[:maxTLSDataLen] + b = b[maxTLSDataLen:] + } else { + b = b[:0] + } + record := &dissector.Record{ + Type: dissector.AppData, + Version: tls.VersionTLS12, + Opaque: data, + } + + if c.wbuf.Len() > 0 { + record.Type = dissector.Handshake + record.WriteTo(&c.wbuf) + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + if _, err = record.WriteTo(c.Conn); err != nil { + return + } + } + return +} + +type obfs4Context struct { + cf base.ClientFactory + cargs interface{} // type obfs4ClientArgs + sf base.ServerFactory + sargs *pt.Args +} + +var obfs4Map = make(map[string]obfs4Context) + +// Obfs4Init initializes the obfs client or server based on isServeNode +func Obfs4Init(node Node, isServeNode bool) error { + if _, ok := obfs4Map[node.Addr]; ok { + return fmt.Errorf("obfs4 context already inited") + } + + t := new(obfs4.Transport) + + stateDir := node.Values.Get("state-dir") + if stateDir == "" { + stateDir = "." + } + + ptArgs := pt.Args(node.Values) + + if !isServeNode { + cf, err := t.ClientFactory(stateDir) + if err != nil { + return err + } + + cargs, err := cf.ParseArgs(&ptArgs) + if err != nil { + return err + } + + obfs4Map[node.Addr] = obfs4Context{cf: cf, cargs: cargs} + } else { + sf, err := t.ServerFactory(stateDir, &ptArgs) + if err != nil { + return err + } + + sargs := sf.Args() + + obfs4Map[node.Addr] = obfs4Context{sf: sf, sargs: sargs} + + log.Log("[obfs4] server inited:", obfs4ServerURL(node)) + } + + return nil +} + +func obfs4GetContext(addr string) (obfs4Context, error) { + ctx, ok := obfs4Map[addr] + if !ok { + return obfs4Context{}, fmt.Errorf("obfs4 context not inited") + } + return ctx, nil +} + +func obfs4ServerURL(node Node) string { + ctx, err := obfs4GetContext(node.Addr) + if err != nil { + return "" + } + + values := (*url.Values)(ctx.sargs) + query := values.Encode() + return fmt.Sprintf( + "%s+%s://%s/?%s", //obfs4-cert=%s&iat-mode=%s", + node.Protocol, + node.Transport, + node.Addr, + query, + ) +} + +func obfs4ClientConn(addr string, conn net.Conn) (net.Conn, error) { + ctx, err := obfs4GetContext(addr) + if err != nil { + return nil, err + } + + pseudoDial := func(a, b string) (net.Conn, error) { return conn, nil } + return ctx.cf.Dial("tcp", "", pseudoDial, ctx.cargs) +} + +func obfs4ServerConn(addr string, conn net.Conn) (net.Conn, error) { + ctx, err := obfs4GetContext(addr) + if err != nil { + return nil, err + } + + return ctx.sf.WrapConn(conn) +} + +type obfs4Transporter struct { + tcpTransporter +} + +// Obfs4Transporter creates a Transporter that is used by obfs4 client. +func Obfs4Transporter() Transporter { + return &obfs4Transporter{} +} + +func (tr *obfs4Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + return obfs4ClientConn(opts.Addr, conn) +} + +type obfs4Listener struct { + addr string + net.Listener +} + +// Obfs4Listener creates a Listener for obfs4 server. +func Obfs4Listener(addr string) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + l := &obfs4Listener{ + addr: addr, + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, + } + return l, nil +} + +func (l *obfs4Listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + cc, err := obfs4ServerConn(l.addr, conn) + if err != nil { + conn.Close() + return nil, err + } + return cc, nil +} diff --git a/obfs_test.go b/obfs_test.go new file mode 100644 index 0000000..cd33702 --- /dev/null +++ b/obfs_test.go @@ -0,0 +1,424 @@ +package gost + +import ( + "crypto/rand" + "fmt" + "net/http/httptest" + "net/url" + "testing" +) + +func httpOverObfsHTTPRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := ObfsHTTPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverObfsHTTP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := httpOverObfsHTTPRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + }) + } +} + +func BenchmarkHTTPOverObfsHTTP(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := ObfsHTTPListener("") + if err != nil { + b.Error(err) + } + // b.Log(ln.Addr()) + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverObfsHTTPParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := ObfsHTTPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverObfsHTTPRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := ObfsHTTPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverObfsHTTP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverObfsHTTPRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverObfsHTTPRoundtrip(targetURL string, data []byte) error { + ln, err := ObfsHTTPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverObfsHTTP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverObfsHTTPRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverObfsHTTPRoundtrip(targetURL string, data []byte) error { + ln, err := ObfsHTTPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverObfsHTTP(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverObfsHTTPRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverObfsHTTPRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := ObfsHTTPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverObfsHTTP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverObfsHTTPRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverObfsHTTPRoundtrip(targetURL string, data []byte, host string) error { + ln, err := ObfsHTTPListener("") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: ObfsHTTPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverObfsHTTP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverObfsHTTPRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func httpOverObfs4Roundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := Obfs4Listener("") + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: Obfs4Transporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func _TestHTTPOverObfs4(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := httpOverObfs4Roundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + }) + } +} diff --git a/permissions.go b/permissions.go new file mode 100644 index 0000000..a2943fa --- /dev/null +++ b/permissions.go @@ -0,0 +1,223 @@ +package gost + +import ( + "errors" + "fmt" + "net" + "strconv" + "strings" + + glob "github.com/ryanuber/go-glob" +) + +// Permission is a rule for blacklist and whitelist. +type Permission struct { + Actions StringSet + Hosts StringSet + Ports PortSet +} + +// PortRange specifies the range of port, such as 1000-2000. +type PortRange struct { + Min, Max int +} + +// ParsePortRange parses the s to a PortRange. +// The s may be a '*' means 0-65535. +func ParsePortRange(s string) (*PortRange, error) { + if s == "*" { + return &PortRange{Min: 0, Max: 65535}, nil + } + + minmax := strings.Split(s, "-") + switch len(minmax) { + case 1: + port, err := strconv.Atoi(s) + if err != nil { + return nil, err + } + if port < 0 || port > 65535 { + return nil, fmt.Errorf("invalid port: %s", s) + } + return &PortRange{Min: port, Max: port}, nil + case 2: + min, err := strconv.Atoi(minmax[0]) + if err != nil { + return nil, err + } + max, err := strconv.Atoi(minmax[1]) + if err != nil { + return nil, err + } + + realmin := maxint(0, minint(min, max)) + realmax := minint(65535, maxint(min, max)) + + return &PortRange{Min: realmin, Max: realmax}, nil + default: + return nil, fmt.Errorf("invalid range: %s", s) + } +} + +// Contains checks whether the value is within this range. +func (ir *PortRange) Contains(value int) bool { + return value >= ir.Min && value <= ir.Max +} + +// PortSet is a set of PortRange +type PortSet []PortRange + +// ParsePortSet parses the s to a PortSet. +// The s shoud be a comma separated string. +func ParsePortSet(s string) (*PortSet, error) { + ps := &PortSet{} + + if s == "" { + return nil, errors.New("must specify at least one port") + } + + ranges := strings.Split(s, ",") + + for _, r := range ranges { + portRange, err := ParsePortRange(r) + + if err != nil { + return nil, err + } + + *ps = append(*ps, *portRange) + } + + return ps, nil +} + +// Contains checks whether the value is within this port set. +func (ps *PortSet) Contains(value int) bool { + for _, portRange := range *ps { + if portRange.Contains(value) { + return true + } + } + + return false +} + +// StringSet is a set of string. +type StringSet []string + +// ParseStringSet parses the s to a StringSet. +// The s shoud be a comma separated string. +func ParseStringSet(s string) (*StringSet, error) { + ss := &StringSet{} + if s == "" { + return nil, errors.New("cannot be empty") + } + + *ss = strings.Split(s, ",") + + return ss, nil +} + +// Contains checks whether the string subj within this StringSet. +func (ss *StringSet) Contains(subj string) bool { + for _, s := range *ss { + if glob.Glob(s, subj) { + return true + } + } + + return false +} + +// Permissions is a set of Permission. +type Permissions []Permission + +// ParsePermissions parses the s to a Permissions. +func ParsePermissions(s string) (*Permissions, error) { + ps := &Permissions{} + + if s == "" { + return &Permissions{}, nil + } + + perms := strings.Split(s, " ") + + for _, perm := range perms { + parts := strings.Split(perm, ":") + + switch len(parts) { + case 3: + actions, err := ParseStringSet(parts[0]) + + if err != nil { + return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0]) + } + + hosts, err := ParseStringSet(parts[1]) + + if err != nil { + return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1]) + } + + ports, err := ParsePortSet(parts[2]) + + if err != nil { + return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2]) + } + + permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports} + + *ps = append(*ps, permission) + default: + return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm) + } + } + + return ps, nil +} + +// Can tests whether the given action and host:port is allowed by this Permissions. +func (ps *Permissions) Can(action string, host string, port int) bool { + for _, p := range *ps { + if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) { + return true + } + } + + return false +} + +func minint(x, y int) int { + if x < y { + return x + } + return y +} + +func maxint(x, y int) int { + if x > y { + return x + } + return y +} + +// Can tests whether the given action and address is allowed by the whitelist and blacklist. +func Can(action string, addr string, whitelist, blacklist *Permissions) bool { + if !strings.Contains(addr, ":") { + addr = addr + ":80" + } + host, strport, err := net.SplitHostPort(addr) + + if err != nil { + return false + } + + port, err := strconv.Atoi(strport) + + if err != nil { + return false + } + + return (whitelist == nil || whitelist.Can(action, host, port)) && + (blacklist == nil || !blacklist.Can(action, host, port)) +} diff --git a/permissions_test.go b/permissions_test.go new file mode 100644 index 0000000..bc99824 --- /dev/null +++ b/permissions_test.go @@ -0,0 +1,152 @@ +package gost + +import ( + "fmt" + "testing" +) + +var portRangeTests = []struct { + in string + out *PortRange +}{ + {"1", &PortRange{Min: 1, Max: 1}}, + {"1-3", &PortRange{Min: 1, Max: 3}}, + {"3-1", &PortRange{Min: 1, Max: 3}}, + {"0-100000", &PortRange{Min: 0, Max: 65535}}, + {"*", &PortRange{Min: 0, Max: 65535}}, +} + +var stringSetTests = []struct { + in string + out *StringSet +}{ + {"*", &StringSet{"*"}}, + {"google.pl,google.com", &StringSet{"google.pl", "google.com"}}, +} + +var portSetTests = []struct { + in string + out *PortSet +}{ + {"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}}, + {"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}}, + {"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}}, + {"*", &PortSet{PortRange{Min: 0, Max: 65535}}}, +} + +var permissionsTests = []struct { + in string + out *Permissions +}{ + {"", &Permissions{}}, + {"*:*:*", &Permissions{ + Permission{ + Actions: StringSet{"*"}, + Hosts: StringSet{"*"}, + Ports: PortSet{PortRange{Min: 0, Max: 65535}}, + }, + }}, + {"bind:127.0.0.1,localhost:80,443,8000-8100 connect:*.google.pl:80,443", &Permissions{ + Permission{ + Actions: StringSet{"bind"}, + Hosts: StringSet{"127.0.0.1", "localhost"}, + Ports: PortSet{ + PortRange{Min: 80, Max: 80}, + PortRange{Min: 443, Max: 443}, + PortRange{Min: 8000, Max: 8100}, + }, + }, + Permission{ + Actions: StringSet{"connect"}, + Hosts: StringSet{"*.google.pl"}, + Ports: PortSet{ + PortRange{Min: 80, Max: 80}, + PortRange{Min: 443, Max: 443}, + }, + }, + }}, +} + +func TestPortRangeParse(t *testing.T) { + for _, test := range portRangeTests { + actual, err := ParsePortRange(test.in) + if err != nil { + t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) + } else if *actual != *test.out { + t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) + } + } +} + +func TestPortRangeContains(t *testing.T) { + actual, _ := ParsePortRange("5-10") + + if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { + t.Errorf("5-10 should contain 5, 7 and 10") + } + + if actual.Contains(4) || actual.Contains(11) { + t.Errorf("5-10 should not contain 4, 11") + } +} + +func TestStringSetParse(t *testing.T) { + for _, test := range stringSetTests { + actual, err := ParseStringSet(test.in) + if err != nil { + t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err) + } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { + t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out) + } + } +} + +func TestStringSetContains(t *testing.T) { + ss, _ := ParseStringSet("google.pl,*.google.com") + + if !ss.Contains("google.pl") || !ss.Contains("www.google.com") { + t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com") + } + + if ss.Contains("www.google.pl") || ss.Contains("foobar.com") { + t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com") + } +} + +func TestPortSetParse(t *testing.T) { + for _, test := range portSetTests { + actual, err := ParsePortSet(test.in) + if err != nil { + t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) + } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { + t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) + } + } +} + +func TestPortSetContains(t *testing.T) { + actual, _ := ParsePortSet("5-10,20-30") + + if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { + t.Errorf("5-10,20-30 should contain 5, 7 and 10") + } + + if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) { + t.Errorf("5-10,20-30 should contain 20, 27 and 30") + } + + if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) { + t.Errorf("5-10,20-30 should not contain 4, 11, 31") + } +} + +func TestPermissionsParse(t *testing.T) { + for _, test := range permissionsTests { + actual, err := ParsePermissions(test.in) + if err != nil { + t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err) + } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { + t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out) + } + } +} diff --git a/quic.go b/quic.go new file mode 100644 index 0000000..028f705 --- /dev/null +++ b/quic.go @@ -0,0 +1,378 @@ +package gost + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/tls" + "errors" + "io" + "net" + "sync" + "time" + + "github.com/go-log/log" + quic "github.com/lucas-clemente/quic-go" +) + +type quicSession struct { + conn net.Conn + session quic.Session +} + +func (session *quicSession) GetConn() (*quicConn, error) { + stream, err := session.session.OpenStreamSync(context.Background()) + if err != nil { + return nil, err + } + return &quicConn{ + Stream: stream, + laddr: session.session.LocalAddr(), + raddr: session.session.RemoteAddr(), + }, nil +} + +func (session *quicSession) Close() error { + return session.session.CloseWithError(quic.ApplicationErrorCode(0), "closed") +} + +type quicTransporter struct { + config *QUICConfig + sessionMutex sync.Mutex + sessions map[string]*quicSession +} + +// QUICTransporter creates a Transporter that is used by QUIC proxy client. +func QUICTransporter(config *QUICConfig) Transporter { + if config == nil { + config = &QUICConfig{} + } + return &quicTransporter{ + config: config, + sessions: make(map[string]*quicSession), + } +} + +func (tr *quicTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if !ok { + var cc *net.UDPConn + cc, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + return + } + conn = cc + + if tr.config != nil && tr.config.Key != nil { + conn = &quicCipherConn{UDPConn: cc, key: tr.config.Key} + } + + session = &quicSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *quicTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + config := tr.config + if opts.QUICConfig != nil { + config = opts.QUICConfig + } + if config.TLSConfig == nil { + config.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("quic: unrecognized connection") + } + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *quicTransporter) initSession(addr string, conn net.Conn, config *QUICConfig) (*quicSession, error) { + udpConn, ok := conn.(net.PacketConn) + if !ok { + return nil, errors.New("quic: wrong connection type") + } + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + quicConfig := &quic.Config{ + HandshakeIdleTimeout: config.Timeout, + KeepAlive: config.KeepAlive, + Versions: []quic.VersionNumber{ + quic.Version1, + quic.VersionDraft29, + }, + MaxIdleTimeout: config.IdleTimeout, + } + session, err := quic.Dial(udpConn, udpAddr, addr, tlsConfigQUICALPN(config.TLSConfig), quicConfig) + if err != nil { + log.Logf("quic dial %s: %v", addr, err) + return nil, err + } + return &quicSession{conn: conn, session: session}, nil +} + +func (tr *quicTransporter) Multiplex() bool { + return true +} + +// QUICConfig is the config for QUIC client and server +type QUICConfig struct { + TLSConfig *tls.Config + Timeout time.Duration + KeepAlive bool + IdleTimeout time.Duration + Key []byte +} + +type quicListener struct { + ln quic.Listener + connChan chan net.Conn + errChan chan error +} + +// QUICListener creates a Listener for QUIC proxy server. +func QUICListener(addr string, config *QUICConfig) (Listener, error) { + if config == nil { + config = &QUICConfig{} + } + quicConfig := &quic.Config{ + HandshakeIdleTimeout: config.Timeout, + KeepAlive: config.KeepAlive, + MaxIdleTimeout: config.IdleTimeout, + Versions: []quic.VersionNumber{ + quic.Version1, + quic.VersionDraft29, + }, + } + + tlsConfig := config.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + var conn net.PacketConn + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + lconn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + conn = lconn + + if config.Key != nil { + conn = &quicCipherConn{UDPConn: lconn, key: config.Key} + } + + ln, err := quic.Listen(conn, tlsConfigQUICALPN(tlsConfig), quicConfig) + if err != nil { + return nil, err + } + + l := &quicListener{ + ln: ln, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +func (l *quicListener) listenLoop() { + for { + session, err := l.ln.Accept(context.Background()) + if err != nil { + log.Log("[quic] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.sessionLoop(session) + } +} + +func (l *quicListener) sessionLoop(session quic.Session) { + log.Logf("[quic] %s <-> %s", session.RemoteAddr(), session.LocalAddr()) + defer log.Logf("[quic] %s >-< %s", session.RemoteAddr(), session.LocalAddr()) + + for { + stream, err := session.AcceptStream(context.Background()) + if err != nil { + log.Log("[quic] accept stream:", err) + session.CloseWithError(quic.ApplicationErrorCode(0), "closed") + return + } + + cc := &quicConn{Stream: stream, laddr: session.LocalAddr(), raddr: session.RemoteAddr()} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[quic] %s - %s: connection queue is full", session.RemoteAddr(), session.LocalAddr()) + } + } +} + +func (l *quicListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *quicListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *quicListener) Close() error { + return l.ln.Close() +} + +type quicConn struct { + quic.Stream + laddr net.Addr + raddr net.Addr +} + +func (c *quicConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *quicConn) RemoteAddr() net.Addr { + return c.raddr +} + +type quicCipherConn struct { + *net.UDPConn + key []byte +} + +func (conn *quicCipherConn) ReadFrom(data []byte) (n int, addr net.Addr, err error) { + n, addr, err = conn.UDPConn.ReadFrom(data) + if err != nil { + return + } + b, err := conn.decrypt(data[:n]) + if err != nil { + return + } + + copy(data, b) + + return len(b), addr, nil +} + +func (conn *quicCipherConn) WriteTo(data []byte, addr net.Addr) (n int, err error) { + b, err := conn.encrypt(data) + if err != nil { + return + } + + _, err = conn.UDPConn.WriteTo(b, addr) + if err != nil { + return + } + + return len(b), nil +} + +func (conn *quicCipherConn) encrypt(data []byte) ([]byte, error) { + c, err := aes.NewCipher(conn.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + return gcm.Seal(nonce, nonce, data, nil), nil +} + +func (conn *quicCipherConn) decrypt(data []byte) ([]byte, error) { + c, err := aes.NewCipher(conn.key) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return nil, errors.New("ciphertext too short") + } + + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + return gcm.Open(nil, nonce, ciphertext, nil) +} + +func tlsConfigQUICALPN(tlsConfig *tls.Config) *tls.Config { + if tlsConfig == nil { + panic("quic: tlsconfig is nil") + } + tlsConfigQUIC := &tls.Config{} + *tlsConfigQUIC = *tlsConfig + tlsConfigQUIC.NextProtos = []string{"http/3", "quic/v1"} + return tlsConfigQUIC +} diff --git a/quic_test.go b/quic_test.go new file mode 100644 index 0000000..3247490 --- /dev/null +++ b/quic_test.go @@ -0,0 +1,463 @@ +package gost + +import ( + "crypto/rand" + "crypto/sha256" + "fmt" + "net/http/httptest" + "net/url" + "testing" +) + +func httpOverQUICRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverQUICRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverQUIC(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := QUICListener("localhost:0", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: QUICTransporter(&QUICConfig{KeepAlive: true}), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverQUICParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := QUICListener("localhost:0", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverQUICRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverQUICRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverQUICRoundtrip(targetURL string, data []byte) error { + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverQUICRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverQUICRoundtrip(targetURL string, data []byte) error { + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverQUIC(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverQUICRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverQUICRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverQUICRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverQUICRoundtrip(targetURL string, data []byte, host string) error { + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverQUICRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func quicForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := QUICListener("localhost:0", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: QUICTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestQUICForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := quicForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func httpOverCipherQUICRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + sum := sha256.Sum256([]byte("12345678")) + cfg := &QUICConfig{ + Key: sum[:], + } + ln, err := QUICListener("localhost:0", cfg) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: QUICTransporter(cfg), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverCipherQUIC(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverCipherQUICRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} diff --git a/redirect.go b/redirect.go new file mode 100644 index 0000000..199d5c2 --- /dev/null +++ b/redirect.go @@ -0,0 +1,242 @@ +//go:build linux +// +build linux + +package gost + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "syscall" + "time" + + "github.com/LiamHaworth/go-tproxy" + "github.com/go-log/log" +) + +type tcpRedirectHandler struct { + options *HandlerOptions +} + +// TCPRedirectHandler creates a server Handler for TCP transparent server. +func TCPRedirectHandler(opts ...HandlerOption) Handler { + h := &tcpRedirectHandler{} + h.Init(opts...) + + return h +} + +func (h *tcpRedirectHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } +} + +func (h *tcpRedirectHandler) Handle(c net.Conn) { + conn, ok := c.(*net.TCPConn) + if !ok { + log.Log("[red-tcp] not a TCP connection") + } + + srcAddr := conn.RemoteAddr() + dstAddr, conn, err := h.getOriginalDstAddr(conn) + if err != nil { + log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) + return + } + defer conn.Close() + + log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) + + cc, err := h.options.Chain.DialContext(context.Background(), + "tcp", dstAddr.String(), + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) + return + } + defer cc.Close() + + log.Logf("[red-tcp] %s <-> %s", srcAddr, dstAddr) + transport(conn, cc) + log.Logf("[red-tcp] %s >-< %s", srcAddr, dstAddr) +} + +func (h *tcpRedirectHandler) getOriginalDstAddr(conn *net.TCPConn) (addr net.Addr, c *net.TCPConn, err error) { + defer conn.Close() + + fc, err := conn.File() + if err != nil { + return + } + defer fc.Close() + + mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80) + if err != nil { + return + } + + // only ipv4 support + ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7]) + port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3]) + addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port)) + if err != nil { + return + } + + cc, err := net.FileConn(fc) + if err != nil { + return + } + + c, ok := cc.(*net.TCPConn) + if !ok { + err = errors.New("not a TCP connection") + } + return +} + +type udpRedirectHandler struct { + options *HandlerOptions +} + +// UDPRedirectHandler creates a server Handler for UDP transparent server. +func UDPRedirectHandler(opts ...HandlerOption) Handler { + h := &udpRedirectHandler{} + h.Init(opts...) + + return h +} + +func (h *udpRedirectHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } +} + +func (h *udpRedirectHandler) Handle(conn net.Conn) { + defer conn.Close() + + raddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + log.Log("[red-udp] wrong connection type") + return + } + + cc, err := h.options.Chain.DialContext(context.Background(), + "udp", raddr.String(), + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err) + return + } + defer cc.Close() + + log.Logf("[red-udp] %s <-> %s", conn.RemoteAddr(), raddr) + transport(conn, cc) + log.Logf("[red-udp] %s >-< %s", conn.RemoteAddr(), raddr) +} + +type udpRedirectListener struct { + *net.UDPConn + config *UDPListenConfig +} + +// UDPRedirectListener creates a Listener for UDP transparent proxy server. +func UDPRedirectListener(addr string, cfg *UDPListenConfig) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + ln, err := tproxy.ListenUDP("udp", laddr) + if err != nil { + return nil, err + } + + if cfg == nil { + cfg = &UDPListenConfig{} + } + return &udpRedirectListener{ + UDPConn: ln, + config: cfg, + }, nil +} + +func (l *udpRedirectListener) Accept() (conn net.Conn, err error) { + b := make([]byte, mediumBufferSize) + + n, raddr, dstAddr, err := tproxy.ReadFromUDP(l.UDPConn, b) + if err != nil { + log.Logf("[red-udp] %s : %s", l.Addr(), err) + return + } + log.Logf("[red-udp] %s: %s -> %s", l.Addr(), raddr, dstAddr) + + c, err := tproxy.DialUDP("udp", dstAddr, raddr) + if err != nil { + log.Logf("[red-udp] %s -> %s : %s", raddr, dstAddr, err) + return + } + + ttl := l.config.TTL + if ttl <= 0 { + ttl = defaultTTL + } + + conn = &udpRedirectServerConn{ + Conn: c, + buf: b[:n], + ttl: ttl, + } + return +} + +func (l *udpRedirectListener) Addr() net.Addr { + return l.UDPConn.LocalAddr() +} + +type udpRedirectServerConn struct { + net.Conn + buf []byte + ttl time.Duration + once sync.Once +} + +func (c *udpRedirectServerConn) Read(b []byte) (n int, err error) { + if c.ttl > 0 { + c.SetReadDeadline(time.Now().Add(c.ttl)) + defer c.SetReadDeadline(time.Time{}) + } + c.once.Do(func() { + n = copy(b, c.buf) + c.buf = nil + }) + + if n == 0 { + n, err = c.Conn.Read(b) + } + return +} + +func (c *udpRedirectServerConn) Write(b []byte) (n int, err error) { + if c.ttl > 0 { + c.SetWriteDeadline(time.Now().Add(c.ttl)) + defer c.SetWriteDeadline(time.Time{}) + } + return c.Conn.Write(b) +} diff --git a/redirect_other.go b/redirect_other.go new file mode 100644 index 0000000..11418cb --- /dev/null +++ b/redirect_other.go @@ -0,0 +1,58 @@ +//go:build !linux +// +build !linux + +package gost + +import ( + "errors" + "net" + + "github.com/go-log/log" +) + +type tcpRedirectHandler struct { + options *HandlerOptions +} + +// TCPRedirectHandler creates a server Handler for TCP redirect server. +func TCPRedirectHandler(opts ...HandlerOption) Handler { + h := &tcpRedirectHandler{ + options: &HandlerOptions{ + Chain: new(Chain), + }, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *tcpRedirectHandler) Init(options ...HandlerOption) { + log.Log("[red-tcp] TCP redirect is not available on the Windows platform") +} + +func (h *tcpRedirectHandler) Handle(c net.Conn) { + log.Log("[red-tcp] TCP redirect is not available on the Windows platform") + c.Close() +} + +type udpRedirectHandler struct { +} + +// UDPRedirectHandler creates a server Handler for UDP transparent server. +func UDPRedirectHandler(opts ...HandlerOption) Handler { + return &udpRedirectHandler{} +} + +func (h *udpRedirectHandler) Init(options ...HandlerOption) { +} + +func (h *udpRedirectHandler) Handle(conn net.Conn) { + log.Log("[red-udp] UDP redirect is not available on the Windows platform") + conn.Close() +} + +// UDPRedirectListener creates a Listener for UDP transparent proxy server. +func UDPRedirectListener(addr string, cfg *UDPListenConfig) (Listener, error) { + return nil, errors.New("UDP redirect is not available on the Windows platform") +} diff --git a/relay.go b/relay.go new file mode 100644 index 0000000..74423f4 --- /dev/null +++ b/relay.go @@ -0,0 +1,369 @@ +package gost + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/url" + "strconv" + "sync" + "time" + + "github.com/go-gost/relay" + "github.com/go-log/log" +) + +type relayConnector struct { + user *url.Userinfo + remoteAddr string +} + +// RelayConnector creates a Connector for TCP/UDP data relay. +func RelayConnector(user *url.Userinfo) Connector { + return &relayConnector{ + user: user, + } +} + +func (c *relayConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return conn, nil +} + +func (c *relayConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + var udp bool + if network == "udp" || network == "udp4" || network == "udp6" { + udp = true + } + + req := &relay.Request{ + Version: relay.Version1, + } + if udp { + req.Flags |= relay.FUDP + } + + if c.user != nil { + pwd, _ := c.user.Password() + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: c.user.Username(), + Password: pwd, + }) + } + if address != "" { + host, port, _ := net.SplitHostPort(address) + nport, _ := strconv.ParseUint(port, 10, 16) + if host == "" { + host = net.IPv4zero.String() + } + + if nport > 0 { + var atype uint8 + ip := net.ParseIP(host) + if ip == nil { + atype = relay.AddrDomain + } else if ip.To4() == nil { + atype = relay.AddrIPv6 + } else { + atype = relay.AddrIPv4 + } + + req.Features = append(req.Features, &relay.AddrFeature{ + AType: atype, + Host: host, + Port: uint16(nport), + }) + } + } + + rc := &relayConn{ + udp: udp, + Conn: conn, + } + + // write the header at once. + if opts.NoDelay { + if _, err := req.WriteTo(rc); err != nil { + return nil, err + } + } else { + if _, err := req.WriteTo(&rc.wbuf); err != nil { + return nil, err + } + } + + return rc, nil +} + +type relayHandler struct { + *baseForwardHandler +} + +// RelayHandler creates a server Handler for TCP/UDP relay server. +func RelayHandler(raddr string, opts ...HandlerOption) Handler { + h := &relayHandler{ + baseForwardHandler: &baseForwardHandler{ + raddr: raddr, + group: NewNodeGroup(), + options: &HandlerOptions{}, + }, + } + for _, opt := range opts { + opt(h.options) + } + return h +} + +func (h *relayHandler) Init(options ...HandlerOption) { + h.baseForwardHandler.Init(options...) +} + +func (h *relayHandler) Handle(conn net.Conn) { + defer conn.Close() + + req := &relay.Request{} + if _, err := req.ReadFrom(conn); err != nil { + log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + if req.Version != relay.Version1 { + log.Logf("[relay] %s - %s : bad version", conn.RemoteAddr(), conn.LocalAddr()) + return + } + + var user, pass string + var raddr string + for _, f := range req.Features { + if f.Type() == relay.FeatureUserAuth { + feature := f.(*relay.UserAuthFeature) + user, pass = feature.Username, feature.Password + } + if f.Type() == relay.FeatureAddr { + feature := f.(*relay.AddrFeature) + raddr = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + } + } + + resp := &relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) { + resp.Status = relay.StatusUnauthorized + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) + return + } + + if raddr != "" { + if len(h.group.Nodes()) > 0 { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : relay to %s is forbidden", + conn.RemoteAddr(), conn.LocalAddr(), raddr) + return + } + } else { + if len(h.group.Nodes()) == 0 { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : bad request, target addr is needed", + conn.RemoteAddr(), conn.LocalAddr()) + return + } + } + + udp := (req.Flags & relay.FUDP) == relay.FUDP + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + network := "tcp" + if udp { + network = "udp" + } + if !Can(network, raddr, h.options.Whitelist, h.options.Blacklist) { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : relay to %s is forbidden", + conn.RemoteAddr(), conn.LocalAddr(), raddr) + return + } + + ctx := context.TODO() + var cc net.Conn + var node Node + var err error + for i := 0; i < retries; i++ { + if len(h.group.Nodes()) > 0 { + node, err = h.group.Next() + if err != nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + raddr = node.Addr + } + + log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr) + cc, err = h.options.Chain.DialContext(ctx, + network, raddr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + ) + if err != nil { + log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err) + node.MarkDead() + } else { + break + } + } + if err != nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + return + } + + node.ResetDead() + defer cc.Close() + + sc := &relayConn{ + Conn: conn, + isServer: true, + udp: udp, + } + resp.WriteTo(&sc.wbuf) + conn = sc + + log.Logf("[relay] %s <-> %s", conn.RemoteAddr(), raddr) + transport(conn, cc) + log.Logf("[relay] %s >-< %s", conn.RemoteAddr(), raddr) +} + +type relayConn struct { + net.Conn + isServer bool + udp bool + wbuf bytes.Buffer + once sync.Once + headerSent bool +} + +func (c *relayConn) Read(b []byte) (n int, err error) { + c.once.Do(func() { + if c.isServer { + return + } + resp := new(relay.Response) + _, err = resp.ReadFrom(c.Conn) + if err != nil { + return + } + if resp.Version != relay.Version1 { + err = relay.ErrBadVersion + return + } + if resp.Status != relay.StatusOK { + err = fmt.Errorf("status %d", resp.Status) + return + } + }) + + if err != nil { + log.Logf("[relay] %s <- %s: %s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), err) + return + } + + if !c.udp { + return c.Conn.Read(b) + } + var bb [2]byte + _, err = io.ReadFull(c.Conn, bb[:]) + if err != nil { + return + } + dlen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dlen { + return io.ReadFull(c.Conn, b[:dlen]) + } + buf := make([]byte, dlen) + _, err = io.ReadFull(c.Conn, buf) + n = copy(b, buf) + return +} + +func (c *relayConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.Read(b) + addr = c.Conn.RemoteAddr() + return +} + +func (c *relayConn) Write(b []byte) (n int, err error) { + if len(b) > 0xFFFF { + err = errors.New("write: data maximum exceeded") + return + } + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + if c.udp { + var bb [2]byte + binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) + c.wbuf.Write(bb[:]) + c.headerSent = true + } + c.wbuf.Write(b) // append the data to the cached header + // _, err = c.Conn.Write(c.wbuf.Bytes()) + // c.wbuf.Reset() + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + if !c.udp { + return c.Conn.Write(b) + } + if !c.headerSent { + c.headerSent = true + b2 := make([]byte, len(b)+2) + copy(b2, b) + _, err = c.Conn.Write(b2) + return + } + nsize := 2 + len(b) + var buf []byte + if nsize <= mediumBufferSize { + buf = mPool.Get().([]byte) + defer mPool.Put(buf) + } else { + buf = make([]byte, nsize) + } + binary.BigEndian.PutUint16(buf[:2], uint16(len(b))) + n = copy(buf[2:], b) + _, err = c.Conn.Write(buf[:nsize]) + return +} + +func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} diff --git a/reload.go b/reload.go new file mode 100644 index 0000000..08d708a --- /dev/null +++ b/reload.go @@ -0,0 +1,65 @@ +package gost + +import ( + "io" + "os" + "time" + + "github.com/go-log/log" +) + +// Reloader is the interface for objects that support live reloading. +type Reloader interface { + Reload(r io.Reader) error + Period() time.Duration +} + +// Stoppable is the interface that indicates a Reloader can be stopped. +type Stoppable interface { + Stop() + Stopped() bool +} + +// PeriodReload reloads the config configFile periodically according to the period of the Reloader r. +func PeriodReload(r Reloader, configFile string) error { + if r == nil || configFile == "" { + return nil + } + + var lastMod time.Time + for { + if r.Period() < 0 { + log.Log("[reload] stopped:", configFile) + return nil + } + + f, err := os.Open(configFile) + if err != nil { + return err + } + + mt := lastMod + if finfo, err := f.Stat(); err == nil { + mt = finfo.ModTime() + } + + if !lastMod.IsZero() && !mt.Equal(lastMod) { + log.Log("[reload]", configFile) + if err := r.Reload(f); err != nil { + log.Logf("[reload] %s: %s", configFile, err) + } + } + f.Close() + lastMod = mt + + period := r.Period() + if period == 0 { + log.Log("[reload] disabled:", configFile) + return nil + } + if period < time.Second { + period = time.Second + } + <-time.After(period) + } +} diff --git a/resolver.go b/resolver.go new file mode 100644 index 0000000..618d724 --- /dev/null +++ b/resolver.go @@ -0,0 +1,914 @@ +package gost + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/go-log/log" + "github.com/miekg/dns" +) + +var ( + // DefaultResolverTimeout is the default timeout for name resolution. + DefaultResolverTimeout = 5 * time.Second +) + +type nameServerOptions struct { + timeout time.Duration + chain *Chain +} + +// NameServerOption allows a common way to set name server options. +type NameServerOption func(*nameServerOptions) + +// TimeoutNameServerOption sets the timeout for name server. +func TimeoutNameServerOption(timeout time.Duration) NameServerOption { + return func(opts *nameServerOptions) { + opts.timeout = timeout + } +} + +// ChainNameServerOption sets the chain for name server. +func ChainNameServerOption(chain *Chain) NameServerOption { + return func(opts *nameServerOptions) { + opts.chain = chain + } +} + +// NameServer is a name server. +// Currently supported protocol: TCP, UDP and TLS. +type NameServer struct { + Addr string + Protocol string + Hostname string // for TLS handshake verification + exchanger Exchanger + options nameServerOptions +} + +// Init initializes the name server. +func (ns *NameServer) Init(opts ...NameServerOption) error { + for _, opt := range opts { + opt(&ns.options) + } + + options := []ExchangerOption{ + TimeoutExchangerOption(ns.options.timeout), + } + protocol := strings.ToLower(ns.Protocol) + switch protocol { + case "tcp", "tcp-chain": + if protocol == "tcp-chain" { + options = append(options, ChainExchangerOption(ns.options.chain)) + } + ns.exchanger = NewDNSTCPExchanger(ns.Addr, options...) + case "tls", "tls-chain": + if protocol == "tls-chain" { + options = append(options, ChainExchangerOption(ns.options.chain)) + } + cfg := &tls.Config{ + ServerName: ns.Hostname, + } + if cfg.ServerName == "" { + cfg.InsecureSkipVerify = true + } + ns.exchanger = NewDoTExchanger(ns.Addr, cfg, options...) + case "https", "https-chain": + if protocol == "https-chain" { + options = append(options, ChainExchangerOption(ns.options.chain)) + } + u, err := url.Parse(ns.Addr) + if err != nil { + return err + } + u.Scheme = "https" + cfg := &tls.Config{ServerName: ns.Hostname} + if cfg.ServerName == "" { + cfg.InsecureSkipVerify = true + } + ns.exchanger = NewDoHExchanger(u, cfg, options...) + case "udp", "udp-chain": + fallthrough + default: + if protocol == "udp-chain" { + options = append(options, ChainExchangerOption(ns.options.chain)) + } + ns.exchanger = NewDNSExchanger(ns.Addr, options...) + } + + return nil +} + +func (ns *NameServer) String() string { + addr := ns.Addr + prot := ns.Protocol + if prot == "" { + prot = "udp" + } + return fmt.Sprintf("%s/%s", addr, prot) +} + +type resolverOptions struct { + chain *Chain + timeout time.Duration + ttl time.Duration + prefer string + srcIP net.IP +} + +// ResolverOption allows a common way to set Resolver options. +type ResolverOption func(*resolverOptions) + +// ChainResolverOption sets the chain for Resolver. +func ChainResolverOption(chain *Chain) ResolverOption { + return func(opts *resolverOptions) { + opts.chain = chain + } +} + +// TimeoutResolverOption sets the timeout for Resolver. +func TimeoutResolverOption(timeout time.Duration) ResolverOption { + return func(opts *resolverOptions) { + opts.timeout = timeout + } +} + +// TTLResolverOption sets the timeout for Resolver. +func TTLResolverOption(ttl time.Duration) ResolverOption { + return func(opts *resolverOptions) { + opts.ttl = ttl + } +} + +// PreferResolverOption sets the prefer for Resolver. +func PreferResolverOption(prefer string) ResolverOption { + return func(opts *resolverOptions) { + opts.prefer = prefer + } +} + +// SrcIPResolverOption sets the source IP for Resolver. +func SrcIPResolverOption(ip net.IP) ResolverOption { + return func(opts *resolverOptions) { + opts.srcIP = ip + } +} + +// Resolver is a name resolver for domain name. +// It contains a list of name servers. +type Resolver interface { + // Init initializes the Resolver instance. + Init(opts ...ResolverOption) error + // Resolve returns a slice of that host's IPv4 and IPv6 addresses. + Resolve(host string) ([]net.IP, error) + // Exchange performs a synchronous query, + // It sends the message query and waits for a reply. + Exchange(ctx context.Context, query []byte) (reply []byte, err error) +} + +// ReloadResolver is resolover that support live reloading. +type ReloadResolver interface { + Resolver + Reloader + Stoppable +} + +type resolver struct { + servers []NameServer + ttl time.Duration + timeout time.Duration + period time.Duration + domain string + cache *resolverCache + stopped chan struct{} + mux sync.RWMutex + prefer string // ipv4 or ipv6 + srcIP net.IP // for edns0 subnet option + options resolverOptions +} + +// NewResolver create a new Resolver with the given name servers and resolution timeout. +func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver { + r := newResolver(ttl, servers...) + return r +} + +func newResolver(ttl time.Duration, servers ...NameServer) *resolver { + return &resolver{ + servers: servers, + cache: newResolverCache(ttl), + stopped: make(chan struct{}), + } +} + +func (r *resolver) Init(opts ...ResolverOption) error { + if r == nil { + return nil + } + + r.mux.Lock() + defer r.mux.Unlock() + + for _, opt := range opts { + opt(&r.options) + } + + timeout := r.timeout + if r.options.timeout != 0 { + timeout = r.options.timeout + } + if timeout <= 0 { + timeout = DefaultResolverTimeout + } + + if r.options.ttl != 0 { + r.ttl = r.options.ttl + } + if r.options.prefer != "" { + r.prefer = r.options.prefer + } + if r.options.srcIP != nil { + r.srcIP = r.options.srcIP + } + + var nss []NameServer + for _, ns := range r.servers { + if err := ns.Init( // init all name servers + ChainNameServerOption(r.options.chain), + TimeoutNameServerOption(timeout), + ); err != nil { + continue // ignore invalid name servers + } + nss = append(nss, ns) + } + + r.servers = nss + + return nil +} + +func (r *resolver) copyServers() []NameServer { + r.mux.RLock() + defer r.mux.RUnlock() + + servers := make([]NameServer, len(r.servers)) + for i := range r.servers { + servers[i] = r.servers[i] + } + + return servers +} + +func (r *resolver) Resolve(host string) (ips []net.IP, err error) { + r.mux.RLock() + domain := r.domain + r.mux.RUnlock() + + if ip := net.ParseIP(host); ip != nil { + return []net.IP{ip}, nil + } + + if !strings.Contains(host, ".") && domain != "" { + host = host + "." + domain + } + + ctx := context.Background() + for _, ns := range r.copyServers() { + ips, err = r.resolve(ctx, ns.exchanger, host) + if err != nil { + log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) + continue + } + + if Debug { + log.Logf("[resolver] %s via %s %v", host, ns.String(), ips) + } + if len(ips) > 0 { + break + } + } + + return +} + +func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { + if ex == nil { + return + } + + r.mux.RLock() + prefer := r.prefer + r.mux.RUnlock() + + if prefer == "ipv6" { // prefer ipv6 + mq := &dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) + ips, err = r.resolveIPs(ctx, ex, mq) + if err != nil || len(ips) > 0 { + return + } + } + + mq := &dns.Msg{} + mq.SetQuestion(dns.Fqdn(host), dns.TypeA) + return r.resolveIPs(ctx, ex, mq) +} + +func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { + key := newResolverCacheKey(&mq.Question[0]) + mr := r.cache.loadCache(key) + if mr == nil { + r.addSubnetOpt(mq) + mr, err = r.exchangeMsg(ctx, ex, mq) + if err != nil { + return + } + r.cache.storeCache(key, mr, r.TTL()) + } + + for _, ans := range mr.Answer { + if ar, _ := ans.(*dns.AAAA); ar != nil { + ips = append(ips, ar.AAAA) + } + if ar, _ := ans.(*dns.A); ar != nil { + ips = append(ips, ar.A) + } + } + + return +} + +func (r *resolver) addSubnetOpt(m *dns.Msg) { + if m == nil || r.srcIP == nil { + return + } + opt := new(dns.OPT) + opt.Hdr.Name = "." + opt.Hdr.Rrtype = dns.TypeOPT + e := new(dns.EDNS0_SUBNET) + e.Code = dns.EDNS0SUBNET + if ip := r.srcIP.To4(); ip != nil { + e.Family = 1 + e.SourceNetmask = 32 + e.Address = ip.To4() + } else { + e.Family = 2 + e.SourceNetmask = 128 + e.Address = r.srcIP + } + opt.Option = append(opt.Option, e) + m.Extra = append(m.Extra, opt) +} + +func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { + mq := &dns.Msg{} + if err = mq.Unpack(query); err != nil { + return + } + + if len(mq.Question) == 0 { + return nil, errors.New("empty question") + } + + var mr *dns.Msg + // Only cache for single question. + if len(mq.Question) == 1 { + key := newResolverCacheKey(&mq.Question[0]) + mr = r.cache.loadCache(key) + if mr != nil { + log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) + mr.Id = mq.Id + return mr.Pack() + } + + defer func() { + if mr != nil { + r.cache.storeCache(key, mr, r.TTL()) + } + }() + } + + r.addSubnetOpt(mq) + + for _, ns := range r.copyServers() { + log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String()) + mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) + if err == nil { + break + } + log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err) + } + if err != nil { + return + } + return mr.Pack() +} + +func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { + query, err := mq.Pack() + if err != nil { + return + } + reply, err := ex.Exchange(ctx, query) + if err != nil { + return + } + + mr = &dns.Msg{} + err = mr.Unpack(reply) + + return +} + +func (r *resolver) TTL() time.Duration { + r.mux.RLock() + defer r.mux.RUnlock() + return r.ttl +} + +func (r *resolver) Reload(rd io.Reader) error { + var ttl, timeout, period time.Duration + var domain, prefer string + var srcIP net.IP + var nss []NameServer + + if rd == nil || r.Stopped() { + return nil + } + + scanner := bufio.NewScanner(rd) + for scanner.Scan() { + line := scanner.Text() + ss := splitLine(line) + if len(ss) == 0 { + continue + } + + switch ss[0] { + case "timeout": // timeout option + if len(ss) > 1 { + timeout, _ = time.ParseDuration(ss[1]) + } + case "ttl": // ttl option + if len(ss) > 1 { + ttl, _ = time.ParseDuration(ss[1]) + } + case "reload": // reload option + if len(ss) > 1 { + period, _ = time.ParseDuration(ss[1]) + } + case "domain": + if len(ss) > 1 { + domain = ss[1] + } + case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf + case "prefer": + if len(ss) > 1 { + prefer = strings.ToLower(ss[1]) + } + case "ip": + if len(ss) > 1 { + srcIP = net.ParseIP(ss[1]) + } + case "nameserver": // nameserver option, compatible with /etc/resolv.conf + if len(ss) <= 1 { + break + } + ss = ss[1:] + fallthrough + default: + var ns NameServer + switch len(ss) { + case 0: + break + case 1: + ns.Addr = ss[0] + case 2: + ns.Addr = ss[0] + ns.Protocol = ss[1] + default: + ns.Addr = ss[0] + ns.Protocol = ss[1] + ns.Hostname = ss[2] + } + + if strings.HasPrefix(ns.Addr, "https") && ns.Protocol == "" { + ns.Protocol = "https" + } + nss = append(nss, ns) + } + } + + if err := scanner.Err(); err != nil { + return err + } + + r.mux.Lock() + r.ttl = ttl + r.timeout = timeout + r.domain = domain + r.period = period + r.prefer = prefer + r.srcIP = srcIP + r.servers = nss + r.mux.Unlock() + + r.Init() + + return nil +} + +func (r *resolver) Period() time.Duration { + if r.Stopped() { + return -1 + } + + r.mux.RLock() + defer r.mux.RUnlock() + + return r.period +} + +// Stop stops reloading. +func (r *resolver) Stop() { + select { + case <-r.stopped: + default: + close(r.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (r *resolver) Stopped() bool { + select { + case <-r.stopped: + return true + default: + return false + } +} + +func (r *resolver) String() string { + if r == nil { + return "" + } + + r.mux.RLock() + defer r.mux.RUnlock() + + b := &bytes.Buffer{} + fmt.Fprintf(b, "TTL %v\n", r.ttl) + fmt.Fprintf(b, "Reload %v\n", r.period) + fmt.Fprintf(b, "Domain %v\n", r.domain) + for i := range r.servers { + fmt.Fprintln(b, r.servers[i]) + } + return b.String() +} + +type resolverCacheKey string + +// newResolverCacheKey generates resolver cache key from question of dns query. +func newResolverCacheKey(q *dns.Question) resolverCacheKey { + if q == nil { + return "" + } + key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String()) + return resolverCacheKey(key) +} + +type resolverCacheItem struct { + mr *dns.Msg + ts int64 + ttl time.Duration +} + +type resolverCache struct { + m sync.Map +} + +func newResolverCache(ttl time.Duration) *resolverCache { + return &resolverCache{} +} + +func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg { + v, ok := rc.m.Load(key) + if !ok { + return nil + } + + item, ok := v.(*resolverCacheItem) + if !ok { + return nil + } + + elapsed := time.Since(time.Unix(item.ts, 0)) + if item.ttl > 0 && elapsed > item.ttl { + rc.m.Delete(key) + return nil + } + for _, rr := range item.mr.Answer { + if elapsed > time.Duration(rr.Header().Ttl)*time.Second { + rc.m.Delete(key) + return nil + } + } + + if Debug { + log.Logf("[resolver] cache hit %s", key) + } + + return item.mr.Copy() +} + +func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) { + if key == "" || mr == nil || ttl < 0 { + return + } + + rc.m.Store(key, &resolverCacheItem{ + mr: mr.Copy(), + ts: time.Now().Unix(), + ttl: ttl, + }) + if Debug { + log.Logf("[resolver] cache store %s", key) + } +} + +// Exchanger is an interface for DNS synchronous query. +type Exchanger interface { + Exchange(ctx context.Context, query []byte) ([]byte, error) +} + +type exchangerOptions struct { + chain *Chain + timeout time.Duration +} + +// ExchangerOption allows a common way to set Exchanger options. +type ExchangerOption func(opts *exchangerOptions) + +// ChainExchangerOption sets the chain for Exchanger. +func ChainExchangerOption(chain *Chain) ExchangerOption { + return func(opts *exchangerOptions) { + opts.chain = chain + } +} + +// TimeoutExchangerOption sets the timeout for Exchanger. +func TimeoutExchangerOption(timeout time.Duration) ExchangerOption { + return func(opts *exchangerOptions) { + opts.timeout = timeout + } +} + +type dnsExchanger struct { + addr string + options exchangerOptions +} + +// NewDNSExchanger creates a DNS over UDP Exchanger +func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") + } + + return &dnsExchanger{ + addr: addr, + options: options, + } +} + +func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + t := time.Now() + c, err := ex.options.chain.DialContext(ctx, + "udp", ex.addr, + TimeoutChainOption(ex.options.timeout), + ) + if err != nil { + return nil, err + } + c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) + defer c.Close() + + conn := &dns.Conn{ + Conn: c, + } + if _, err = conn.Write(query); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() +} + +type dnsTCPExchanger struct { + addr string + options exchangerOptions +} + +// NewDNSTCPExchanger creates a DNS over TCP Exchanger +func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") + } + + return &dnsTCPExchanger{ + addr: addr, + options: options, + } +} + +func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + t := time.Now() + c, err := ex.options.chain.DialContext(ctx, + "tcp", ex.addr, + TimeoutChainOption(ex.options.timeout), + ) + if err != nil { + return nil, err + } + c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) + defer c.Close() + + conn := &dns.Conn{ + Conn: c, + } + if _, err = conn.Write(query); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() +} + +type dotExchanger struct { + addr string + tlsConfig *tls.Config + options exchangerOptions +} + +// NewDoTExchanger creates a DNS over TLS Exchanger +func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "53") + } + + if tlsConfig == nil { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + return &dotExchanger{ + addr: addr, + tlsConfig: tlsConfig, + options: options, + } +} + +func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + conn, err = ex.options.chain.DialContext(ctx, + network, address, + TimeoutChainOption(ex.options.timeout), + ) + if err != nil { + return + } + conn = tls.Client(conn, ex.tlsConfig) + + return +} + +func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + t := time.Now() + c, err := ex.dial(ctx, "tcp", ex.addr) + if err != nil { + return nil, err + } + c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) + defer c.Close() + + conn := &dns.Conn{ + Conn: c, + } + if _, err = conn.Write(query); err != nil { + return nil, err + } + + mr, err := conn.ReadMsg() + if err != nil { + return nil, err + } + + return mr.Pack() +} + +type dohExchanger struct { + endpoint *url.URL + client *http.Client + options exchangerOptions +} + +// NewDoHExchanger creates a DNS over HTTPS Exchanger +func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { + var options exchangerOptions + for _, opt := range opts { + opt(&options) + } + ex := &dohExchanger{ + endpoint: urlStr, + options: options, + } + + ex.client = &http.Client{ + Timeout: options.timeout, + Transport: &http.Transport{ + // Proxy: ProxyFromEnvironment, + TLSClientConfig: tlsConfig, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: options.timeout, + ExpectContinueTimeout: 1 * time.Second, + DialContext: ex.dialContext, + }, + } + + return ex +} + +func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) { + return ex.options.chain.DialContext(ctx, + network, address, + TimeoutChainOption(ex.options.timeout), + ) +} + +func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "POST", ex.endpoint.String(), bytes.NewBuffer(query)) + if err != nil { + return nil, fmt.Errorf("failed to create an HTTPS request: %s", err) + } + + // req.Header.Add("Content-Type", "application/dns-udpwireformat") + req.Header.Add("Content-Type", "application/dns-message") + req.Host = ex.endpoint.Hostname() + + client := ex.client + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err) + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + // Read wireformat response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read the response body: %s", err) + } + + return buf, nil +} diff --git a/resolver_test.go b/resolver_test.go new file mode 100644 index 0000000..79732ea --- /dev/null +++ b/resolver_test.go @@ -0,0 +1,270 @@ +package gost + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + "time" +) + +var dnsTests = []struct { + ns NameServer + host string + pass bool +}{ + {NameServer{Addr: "1.1.1.1"}, "192.168.1.1", true}, + {NameServer{Addr: "1.1.1.1"}, "github", true}, + {NameServer{Addr: "1.1.1.1"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:53"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true}, + {NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true}, + {NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true}, + {NameServer{Addr: "1.1.1.1:12345"}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false}, + {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false}, + {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false}, +} + +func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error { + ips, err := r.Resolve(host) + t.Log(host, ips, err) + if err != nil { + return err + } + + return nil +} + +func TestDNSResolver(t *testing.T) { + for i, tc := range dnsTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + ns := tc.ns + t.Log(ns) + r := NewResolver(0, ns) + resolv := r.(*resolver) + resolv.domain = "com" + if err := r.Init(); err != nil { + t.Error("got error:", err) + } + err := dnsResolverRoundtrip(t, r, tc.host) + if err != nil { + if tc.pass { + t.Error("got error:", err) + } + } else { + if !tc.pass { + t.Error("should failed") + } + } + }) + } +} + +var resolverCacheTests = []struct { + name string + ips []net.IP + ttl time.Duration + result []net.IP +}{ + {"", nil, 0, nil}, + {"", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil}, + {"", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, nil}, + {"example.com", nil, 10 * time.Second, nil}, + {"example.com", []net.IP{}, 10 * time.Second, nil}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, -1, nil}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, + []net.IP{net.IPv4(192, 168, 1, 1)}}, + {"example.com", []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}, 10 * time.Second, + []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}}, +} + +/* +func TestResolverCache(t *testing.T) { + isEqual := func(a, b []net.IP) bool { + if a == nil && b == nil { + return true + } + + if a == nil || b == nil || len(a) != len(b) { + return false + } + + for i := range a { + if !a[i].Equal(b[i]) { + return false + } + } + return true + } + for i, tc := range resolverCacheTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + r := newResolver(tc.ttl) + r.cache.storeCache(tc.name, tc.ips, tc.ttl) + ips := r.cache.loadCache(tc.name, tc.ttl) + + if !isEqual(tc.result, ips) { + t.Error("unexpected cache value:", tc.name, ips, tc.ttl) + } + }) + } +} +*/ + +var resolverReloadTests = []struct { + r io.Reader + + timeout time.Duration + ttl time.Duration + domain string + period time.Duration + ns *NameServer + + stopped bool +}{ + { + r: nil, + }, + { + r: bytes.NewBufferString(""), + }, + { + r: bytes.NewBufferString("reload 10s"), + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("timeout 10s\nreload 10s\n"), + timeout: 10 * time.Second, + period: 10 * time.Second, + }, + { + r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"), + timeout: 10 * time.Second, + period: 10 * time.Second, + ttl: 10 * time.Second, + }, + { + r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"), + timeout: 10 * time.Second, + period: 10 * time.Second, + ttl: 10 * time.Second, + domain: "example.com", + }, + { + r: bytes.NewBufferString("1.1.1.1"), + ns: &NameServer{ + Addr: "1.1.1.1", + }, + stopped: true, + }, + { + r: bytes.NewBufferString("\n# comment\ntimeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"), + ns: &NameServer{ + Protocol: "udp", + Addr: "1.1.1.1", + }, + timeout: 10 * time.Second, + stopped: true, + }, + { + r: bytes.NewBufferString("1.1.1.1 tcp"), + ns: &NameServer{ + Addr: "1.1.1.1", + Protocol: "tcp", + }, + stopped: true, + }, + { + r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"), + ns: &NameServer{ + Addr: "1.1.1.1:853", + Protocol: "tls", + Hostname: "cloudflare-dns.com", + }, + stopped: true, + }, + { + r: bytes.NewBufferString("1.1.1.1:853 tls"), + ns: &NameServer{ + Addr: "1.1.1.1:853", + Protocol: "tls", + }, + stopped: true, + }, + { + r: bytes.NewBufferString("1.0.0.1:53 https"), + stopped: true, + }, + { + r: bytes.NewBufferString("https://1.0.0.1/dns-query"), + ns: &NameServer{ + Addr: "https://1.0.0.1/dns-query", + Protocol: "https", + }, + stopped: true, + }, +} + +func TestResolverReload(t *testing.T) { + for i, tc := range resolverReloadTests { + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + r := newResolver(0) + if err := r.Reload(tc.r); err != nil { + t.Error(err) + } + t.Log(r.String()) + if r.TTL() != tc.ttl { + t.Errorf("ttl value should be %v, got %v", + tc.ttl, r.TTL()) + } + if r.Period() != tc.period { + t.Errorf("period value should be %v, got %v", + tc.period, r.period) + } + if r.domain != tc.domain { + t.Errorf("domain value should be %v, got %v", + tc.domain, r.domain) + } + + var ns *NameServer + if len(r.servers) > 0 { + ns = &r.servers[0] + } + + if !compareNameServer(ns, tc.ns) { + t.Errorf("nameserver not equal, should be %v, got %v", + tc.ns, r.servers) + } + + if tc.stopped { + r.Stop() + if r.Period() >= 0 { + t.Errorf("period of the stopped reloader should be minus value") + } + } + if r.Stopped() != tc.stopped { + t.Errorf("stopped value should be %v, got %v", + tc.stopped, r.Stopped()) + } + }) + } +} + +func compareNameServer(n1, n2 *NameServer) bool { + if n1 == n2 { + return true + } + if n1 == nil || n2 == nil { + return false + } + return n1.Addr == n2.Addr && + n1.Hostname == n2.Hostname && + n1.Protocol == n2.Protocol +} diff --git a/selector.go b/selector.go new file mode 100644 index 0000000..12545ac --- /dev/null +++ b/selector.go @@ -0,0 +1,294 @@ +package gost + +import ( + "errors" + "math/rand" + "net" + "strconv" + "sync" + "sync/atomic" + "time" +) + +var ( + // ErrNoneAvailable indicates there is no node available. + ErrNoneAvailable = errors.New("none available") +) + +// NodeSelector as a mechanism to pick nodes and mark their status. +type NodeSelector interface { + Select(nodes []Node, opts ...SelectOption) (Node, error) +} + +type defaultSelector struct { +} + +func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) { + sopts := SelectOptions{} + for _, opt := range opts { + opt(&sopts) + } + + for _, filter := range sopts.Filters { + nodes = filter.Filter(nodes) + } + if len(nodes) == 0 { + return Node{}, ErrNoneAvailable + } + strategy := sopts.Strategy + if strategy == nil { + strategy = &RoundStrategy{} + } + return strategy.Apply(nodes), nil +} + +// SelectOption is the option used when making a select call. +type SelectOption func(*SelectOptions) + +// SelectOptions is the options for node selection. +type SelectOptions struct { + Filters []Filter + Strategy Strategy +} + +// WithFilter adds a filter function to the list of filters +// used during the Select call. +func WithFilter(f ...Filter) SelectOption { + return func(o *SelectOptions) { + o.Filters = append(o.Filters, f...) + } +} + +// WithStrategy sets the selector strategy +func WithStrategy(s Strategy) SelectOption { + return func(o *SelectOptions) { + o.Strategy = s + } +} + +// Strategy is a selection strategy e.g random, round-robin. +type Strategy interface { + Apply([]Node) Node + String() string +} + +// NewStrategy creates a Strategy by the name s. +func NewStrategy(s string) Strategy { + switch s { + case "random": + return &RandomStrategy{} + case "fifo": + return &FIFOStrategy{} + case "round": + fallthrough + default: + return &RoundStrategy{} + } +} + +// RoundStrategy is a strategy for node selector. +// The node will be selected by round-robin algorithm. +type RoundStrategy struct { + counter uint64 +} + +// Apply applies the round-robin strategy for the nodes. +func (s *RoundStrategy) Apply(nodes []Node) Node { + if len(nodes) == 0 { + return Node{} + } + + n := atomic.AddUint64(&s.counter, 1) - 1 + return nodes[int(n%uint64(len(nodes)))] +} + +func (s *RoundStrategy) String() string { + return "round" +} + +// RandomStrategy is a strategy for node selector. +// The node will be selected randomly. +type RandomStrategy struct { + Seed int64 + rand *rand.Rand + once sync.Once + mux sync.Mutex +} + +// Apply applies the random strategy for the nodes. +func (s *RandomStrategy) Apply(nodes []Node) Node { + s.once.Do(func() { + seed := s.Seed + if seed == 0 { + seed = time.Now().UnixNano() + } + s.rand = rand.New(rand.NewSource(seed)) + }) + if len(nodes) == 0 { + return Node{} + } + + s.mux.Lock() + r := s.rand.Int() + s.mux.Unlock() + + return nodes[r%len(nodes)] +} + +func (s *RandomStrategy) String() string { + return "random" +} + +// FIFOStrategy is a strategy for node selector. +// The node will be selected from first to last, +// and will stick to the selected node until it is failed. +type FIFOStrategy struct{} + +// Apply applies the fifo strategy for the nodes. +func (s *FIFOStrategy) Apply(nodes []Node) Node { + if len(nodes) == 0 { + return Node{} + } + return nodes[0] +} + +func (s *FIFOStrategy) String() string { + return "fifo" +} + +// Filter is used to filter a node during the selection process +type Filter interface { + Filter([]Node) []Node + String() string +} + +// default options for FailFilter +const ( + DefaultMaxFails = 1 + DefaultFailTimeout = 30 * time.Second +) + +// FailFilter filters the dead node. +// A node is marked as dead if its failed count is greater than MaxFails. +type FailFilter struct { + MaxFails int + FailTimeout time.Duration +} + +// Filter filters dead nodes. +func (f *FailFilter) Filter(nodes []Node) []Node { + maxFails := f.MaxFails + if maxFails == 0 { + maxFails = DefaultMaxFails + } + failTimeout := f.FailTimeout + if failTimeout == 0 { + failTimeout = DefaultFailTimeout + } + + if len(nodes) <= 1 || maxFails < 0 { + return nodes + } + nl := []Node{} + for i := range nodes { + marker := nodes[i].marker.Clone() + // log.Logf("%s: %d/%d %v/%v", nodes[i], marker.FailCount(), f.MaxFails, marker.FailTime(), f.FailTimeout) + if marker.FailCount() < uint32(maxFails) || + time.Since(time.Unix(marker.FailTime(), 0)) >= failTimeout { + nl = append(nl, nodes[i]) + } + } + return nl +} + +func (f *FailFilter) String() string { + return "fail" +} + +// InvalidFilter filters the invalid node. +// A node is invalid if its port is invalid (negative or zero value). +type InvalidFilter struct{} + +// Filter filters invalid nodes. +func (f *InvalidFilter) Filter(nodes []Node) []Node { + nl := []Node{} + for i := range nodes { + _, sport, _ := net.SplitHostPort(nodes[i].Addr) + if port, _ := strconv.Atoi(sport); port > 0 { + nl = append(nl, nodes[i]) + } + } + return nl +} + +func (f *InvalidFilter) String() string { + return "invalid" +} + +type failMarker struct { + failTime int64 + failCount uint32 + mux sync.RWMutex +} + +func (m *failMarker) FailTime() int64 { + if m == nil { + return 0 + } + + m.mux.Lock() + defer m.mux.Unlock() + + return m.failTime +} + +func (m *failMarker) FailCount() uint32 { + if m == nil { + return 0 + } + + m.mux.Lock() + defer m.mux.Unlock() + + return m.failCount +} + +func (m *failMarker) Mark() { + if m == nil { + return + } + + m.mux.Lock() + defer m.mux.Unlock() + + m.failTime = time.Now().Unix() + m.failCount++ +} + +func (m *failMarker) Reset() { + if m == nil { + return + } + + m.mux.Lock() + defer m.mux.Unlock() + + m.failTime = 0 + m.failCount = 0 +} + +func (m *failMarker) Clone() *failMarker { + if m == nil { + return nil + } + + m.mux.RLock() + defer m.mux.RUnlock() + + fc, ft := m.failCount, m.failTime + + return &failMarker{ + failCount: fc, + failTime: ft, + } +} diff --git a/selector_test.go b/selector_test.go new file mode 100644 index 0000000..5da667c --- /dev/null +++ b/selector_test.go @@ -0,0 +1,151 @@ +package gost + +import ( + "testing" + "time" +) + +func TestRoundStrategy(t *testing.T) { + nodes := []Node{ + Node{ID: 1}, + Node{ID: 2}, + Node{ID: 3}, + } + s := NewStrategy("round") + t.Log(s.String()) + + if node := s.Apply(nil); node.ID > 0 { + t.Error("unexpected node", node.String()) + } + for i := 0; i <= len(nodes); i++ { + node := s.Apply(nodes) + if node.ID != nodes[i%len(nodes)].ID { + t.Error("unexpected node", node.String()) + } + } +} + +func TestRandomStrategy(t *testing.T) { + nodes := []Node{ + Node{ID: 1}, + Node{ID: 2}, + Node{ID: 3}, + } + s := NewStrategy("random") + t.Log(s.String()) + + if node := s.Apply(nil); node.ID > 0 { + t.Error("unexpected node", node.String()) + } + for i := 0; i <= len(nodes); i++ { + node := s.Apply(nodes) + if node.ID == 0 { + t.Error("unexpected node", node.String()) + } + } +} + +func TestFIFOStrategy(t *testing.T) { + nodes := []Node{ + Node{ID: 1}, + Node{ID: 2}, + Node{ID: 3}, + } + s := NewStrategy("fifo") + t.Log(s.String()) + + if node := s.Apply(nil); node.ID > 0 { + t.Error("unexpected node", node.String()) + } + for i := 0; i <= len(nodes); i++ { + node := s.Apply(nodes) + if node.ID != 1 { + t.Error("unexpected node", node.String()) + } + } +} + +func TestFailFilter(t *testing.T) { + nodes := []Node{ + Node{ID: 1, marker: &failMarker{}}, + Node{ID: 2, marker: &failMarker{}}, + Node{ID: 3, marker: &failMarker{}}, + } + filter := &FailFilter{} + t.Log(filter.String()) + + isEqual := func(a, b []Node) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil || len(a) != len(b) { + return false + } + + for i := range a { + if a[i].ID != b[i].ID { + return false + } + } + return true + } + if v := filter.Filter(nil); v != nil { + t.Error("unexpected node", v) + } + + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + filter.MaxFails = -1 + nodes[0].MarkDead() + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + filter.MaxFails = 0 + if v := filter.Filter(nodes); isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + filter.FailTimeout = 5 * time.Second + if v := filter.Filter(nodes); isEqual(v, nodes) { + t.Error("unexpected node", v) + } + + nodes[1].MarkDead() + nodes[2].MarkDead() + if v := filter.Filter(nodes); len(v) > 0 { + t.Error("unexpected node", v) + } + + for i := range nodes { + nodes[i].ResetDead() + } + if v := filter.Filter(nodes); !isEqual(v, nodes) { + t.Error("unexpected node", v) + } +} + +func TestSelector(t *testing.T) { + nodes := []Node{ + Node{ID: 1, marker: &failMarker{}}, + Node{ID: 2, marker: &failMarker{}}, + Node{ID: 3, marker: &failMarker{}}, + } + selector := &defaultSelector{} + if _, err := selector.Select(nil); err != ErrNoneAvailable { + t.Error("got unexpected error:", err) + } + + if node, _ := selector.Select(nodes); node.ID != 1 { + t.Error("unexpected node:", node) + } + + if node, _ := selector.Select(nodes, + WithStrategy(NewStrategy("")), + WithFilter(&FailFilter{MaxFails: 1, FailTimeout: 3 * time.Second}), + ); node.ID != 1 { + t.Error("unexpected node:", node) + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..dd8d556 --- /dev/null +++ b/server.go @@ -0,0 +1,128 @@ +package gost + +import ( + "io" + "net" + "time" + + "github.com/go-log/log" +) + +// Accepter represents a network endpoint that can accept connection from peer. +type Accepter interface { + Accept() (net.Conn, error) +} + +// Server is a proxy server. +type Server struct { + Listener Listener + Handler Handler + options *ServerOptions +} + +// Init intializes server with given options. +func (s *Server) Init(opts ...ServerOption) { + if s.options == nil { + s.options = &ServerOptions{} + } + for _, opt := range opts { + opt(s.options) + } +} + +// Addr returns the address of the server +func (s *Server) Addr() net.Addr { + return s.Listener.Addr() +} + +// Close closes the server +func (s *Server) Close() error { + return s.Listener.Close() +} + +// Serve serves as a proxy server. +func (s *Server) Serve(h Handler, opts ...ServerOption) error { + s.Init(opts...) + + if s.Listener == nil { + ln, err := TCPListener("") + if err != nil { + return err + } + s.Listener = ln + } + + if h == nil { + h = s.Handler + } + if h == nil { + h = HTTPHandler() + } + + l := s.Listener + var tempDelay time.Duration + for { + conn, e := l.Accept() + if e != nil { + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + log.Logf("server: Accept error: %v; retrying in %v", e, tempDelay) + time.Sleep(tempDelay) + continue + } + return e + } + tempDelay = 0 + + go h.Handle(conn) + } +} + +// Run starts to serve. +func (s *Server) Run() error { + return s.Serve(s.Handler) +} + +// ServerOptions holds the options for Server. +type ServerOptions struct { +} + +// ServerOption allows a common way to set server options. +type ServerOption func(opts *ServerOptions) + +// Listener is a proxy server listener, just like a net.Listener. +type Listener interface { + net.Listener +} + +func transport(rw1, rw2 io.ReadWriter) error { + errc := make(chan error, 1) + go func() { + errc <- copyBuffer(rw1, rw2) + }() + + go func() { + errc <- copyBuffer(rw2, rw1) + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} + +func copyBuffer(dst io.Writer, src io.Reader) error { + buf := lPool.Get().([]byte) + defer lPool.Put(buf) + + _, err := io.CopyBuffer(dst, src, buf) + return err +} diff --git a/signal.go b/signal.go new file mode 100644 index 0000000..cf7e484 --- /dev/null +++ b/signal.go @@ -0,0 +1,6 @@ +//go:build windows +// +build windows + +package gost + +func kcpSigHandler() {} diff --git a/signal_unix.go b/signal_unix.go new file mode 100644 index 0000000..491484d --- /dev/null +++ b/signal_unix.go @@ -0,0 +1,25 @@ +//go:build !windows +// +build !windows + +package gost + +import ( + "os" + "os/signal" + "syscall" + + "github.com/go-log/log" + "github.com/xtaci/kcp-go" +) + +func kcpSigHandler() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGUSR1) + + for { + switch <-ch { + case syscall.SIGUSR1: + log.Logf("[kcp] SNMP: %+v", kcp.DefaultSnmp.Copy()) + } + } +} diff --git a/snapcraft.yaml b/snapcraft.yaml new file mode 100644 index 0000000..325a0ce --- /dev/null +++ b/snapcraft.yaml @@ -0,0 +1,43 @@ +name: gost +type: app +version: '2.11.2' +title: GO Simple Tunnel +summary: A simple security tunnel written in golang +description: | + https://github.com/ginuerzh/gost +confinement: strict +grade: stable +base: core18 +license: MIT +parts: + gost: + plugin: nil + build-snaps: [go/1.18/stable] + source: https://github.com/ginuerzh/gost.git + source-subdir: cmd/gost + source-type: git + source-branch: '2' + build-packages: + - build-essential + override-build: | + set -ex + + echo "Starting override-build:" + pwd + cd $SNAPCRAFT_PART_BUILD + GO111MODULE=on CGO_ENABLED=0 go build --ldflags="-s -w" + ./gost -V + + echo "Installing to ${SNAPCRAFT_PART_INSTALL}..." + install -d $SNAPCRAFT_PART_INSTALL/bin + cp -v gost $SNAPCRAFT_PART_INSTALL/bin/ + + echo "Override-build done!" +apps: + gost: + command: bin/gost + plugs: + - home + - network + - network-bind + diff --git a/sni.go b/sni.go new file mode 100644 index 0000000..7d4c268 --- /dev/null +++ b/sni.go @@ -0,0 +1,350 @@ +// SNI proxy based on https://github.com/bradfitz/tcpproxy + +package gost + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "io" + "net" + "net/http" + "strings" + "sync" + + "github.com/asaskevich/govalidator" + dissector "github.com/go-gost/tls-dissector" + "github.com/go-log/log" +) + +type sniConnector struct { + host string +} + +// SNIConnector creates a Connector for SNI proxy client. +func SNIConnector(host string) Connector { + return &sniConnector{host: host} +} + +func (c *sniConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *sniConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + return &sniClientConn{addr: address, host: c.host, Conn: conn}, nil +} + +type sniHandler struct { + options *HandlerOptions +} + +// SNIHandler creates a server Handler for SNI proxy server. +func SNIHandler(opts ...HandlerOption) Handler { + h := &sniHandler{} + h.Init(opts...) + + return h +} + +func (h *sniHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } +} + +func (h *sniHandler) Handle(conn net.Conn) { + defer conn.Close() + + br := bufio.NewReader(conn) + hdr, err := br.Peek(dissector.RecordHeaderLen) + if err != nil { + log.Logf("[sni] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + conn = &bufferdConn{br: br, Conn: conn} + + if hdr[0] != dissector.Handshake { + // We assume it is an HTTP request + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + log.Logf("[sni] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + if !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { + req.URL.Scheme = "http" + } + + handler := &httpHandler{options: h.options} + handler.Init() + handler.handleRequest(conn, req) + return + } + + b, host, err := readClientHelloRecord(conn, "", false) + if err != nil { + log.Logf("[sni] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + _, sport, _ := net.SplitHostPort(h.options.Host) + if sport == "" { + sport = "443" + } + host = net.JoinHostPort(host, sport) + + log.Logf("[sni] %s -> %s -> %s", + conn.RemoteAddr(), h.options.Node.String(), host) + + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[sni] %s -> %s : Unauthorized to tcp connect to %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + return + } + if h.options.Bypass.Contains(host) { + log.Log("[sni] %s - %s bypass %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + return + } + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(host) + if err != nil { + log.Logf("[sni] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + continue + } + + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + conn.RemoteAddr(), h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", host) + log.Log("[route]", buf.String()) + + cc, err = route.Dial(host, + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + log.Logf("[sni] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + } + + if err != nil { + return + } + defer cc.Close() + + if _, err := cc.Write(b); err != nil { + log.Logf("[sni] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + } + + log.Logf("[sni] %s <-> %s", cc.LocalAddr(), host) + transport(conn, cc) + log.Logf("[sni] %s >-< %s", cc.LocalAddr(), host) +} + +// sniSniffConn is a net.Conn that reads from r, fails on Writes, +// and crashes otherwise. +type sniSniffConn struct { + r io.Reader + net.Conn // nil; crash on any unexpected use +} + +func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } +func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } + +type sniClientConn struct { + addr string + host string + mutex sync.Mutex + obfuscated bool + net.Conn +} + +func (c *sniClientConn) Write(p []byte) (int, error) { + b, err := c.obfuscate(p) + if err != nil { + return 0, err + } + if _, err = c.Conn.Write(b); err != nil { + return 0, err + } + return len(p), nil +} + +func (c *sniClientConn) obfuscate(p []byte) ([]byte, error) { + if c.host == "" { + return p, nil + } + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.obfuscated { + return p, nil + } + + if p[0] == dissector.Handshake { + b, host, err := readClientHelloRecord(bytes.NewReader(p), c.host, true) + if err != nil { + return nil, err + } + if Debug { + log.Logf("[sni] obfuscate: %s -> %s", c.addr, host) + } + c.obfuscated = true + return b, nil + } + + buf := &bytes.Buffer{} + br := bufio.NewReader(bytes.NewReader(p)) + for { + s, err := br.ReadString('\n') + if err != nil { + if err != io.EOF { + return nil, err + } + if s != "" { + buf.Write([]byte(s)) + } + break + } + + // end of HTTP header + if s == "\r\n" { + buf.Write([]byte(s)) + // drain the remain bytes. + io.Copy(buf, br) + break + } + + if strings.HasPrefix(s, "Host") { + s = strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(s, "Host:"), "\r\n")) + host := encodeServerName(s) + if Debug { + log.Logf("[sni] obfuscate: %s -> %s", s, c.host) + } + buf.WriteString("Host: " + c.host + "\r\n") + buf.WriteString("Gost-Target: " + host + "\r\n") + // drain the remain bytes. + io.Copy(buf, br) + break + } + buf.Write([]byte(s)) + } + c.obfuscated = true + return buf.Bytes(), nil +} + +func readClientHelloRecord(r io.Reader, host string, isClient bool) ([]byte, string, error) { + record, err := dissector.ReadRecord(r) + if err != nil { + return nil, "", err + } + clientHello := &dissector.ClientHelloMsg{} + if err := clientHello.Decode(record.Opaque); err != nil { + return nil, "", err + } + + if !isClient { + var extensions []dissector.Extension + + for _, ext := range clientHello.Extensions { + if ext.Type() == 0xFFFE { + b, _ := ext.Encode() + if host, err = decodeServerName(string(b)); err == nil { + continue + } + } + extensions = append(extensions, ext) + } + clientHello.Extensions = extensions + } + + for _, ext := range clientHello.Extensions { + if ext.Type() == dissector.ExtServerName { + snExtension := ext.(*dissector.ServerNameExtension) + if host == "" { + host = snExtension.Name + } + if isClient { + e, _ := dissector.NewExtension(0xFFFE, []byte(encodeServerName(snExtension.Name))) + clientHello.Extensions = append(clientHello.Extensions, e) + } + if host != "" { + snExtension.Name = host + } + break + } + } + record.Opaque, err = clientHello.Encode() + if err != nil { + return nil, "", err + } + + buf := &bytes.Buffer{} + if _, err := record.WriteTo(buf); err != nil { + return nil, "", err + } + + return buf.Bytes(), host, nil +} + +func encodeServerName(name string) string { + buf := &bytes.Buffer{} + binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE([]byte(name))) + buf.WriteString(base64.RawURLEncoding.EncodeToString([]byte(name))) + return base64.RawURLEncoding.EncodeToString(buf.Bytes()) +} + +func decodeServerName(s string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return "", err + } + if len(b) < 4 { + return "", errors.New("invalid name") + } + v, err := base64.RawURLEncoding.DecodeString(string(b[4:])) + if err != nil { + return "", err + } + if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) { + return "", errors.New("invalid name") + } + return string(v), nil +} diff --git a/sni_test.go b/sni_test.go new file mode 100644 index 0000000..0bfa994 --- /dev/null +++ b/sni_test.go @@ -0,0 +1,148 @@ +package gost + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/tls" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func sniRoundtrip(client *Client, server *Server, targetURL string, data []byte) (err error) { + conn, err := client.Dial(server.Addr().String()) + if err != nil { + return + } + + conn, err = client.Handshake(conn, AddrHandshakeOption(server.Addr().String())) + if err != nil { + return + } + defer conn.Close() + + u, err := url.Parse(targetURL) + if err != nil { + return + } + + conn.SetDeadline(time.Now().Add(3 * time.Second)) + defer conn.SetDeadline(time.Time{}) + + conn, err = client.Connect(conn, u.Host) + if err != nil { + return + } + + if u.Scheme == "https" { + conn = tls.Client(conn, + &tls.Config{ + InsecureSkipVerify: true, + // ServerName: u.Hostname(), + }) + u.Scheme = "http" + } + req, err := http.NewRequest( + http.MethodGet, + u.String(), + bytes.NewReader(data), + ) + if err != nil { + return + } + if err = req.WriteProxy(conn); err != nil { + return + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + + recv, err := ioutil.ReadAll(resp.Body) + if err != nil { + return + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + + return +} + +func sniProxyRoundtrip(targetURL string, data []byte, host string) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniProxyRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} diff --git a/sockopts_linux.go b/sockopts_linux.go new file mode 100644 index 0000000..f35423e --- /dev/null +++ b/sockopts_linux.go @@ -0,0 +1,7 @@ +package gost + +import "syscall" + +func setSocketMark(fd int, value int) (e error) { + return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, value) +} diff --git a/sockopts_other.go b/sockopts_other.go new file mode 100644 index 0000000..c0dd1b8 --- /dev/null +++ b/sockopts_other.go @@ -0,0 +1,7 @@ +//go:build !linux + +package gost + +func setSocketMark(fd int, value int) (e error) { + return nil +} diff --git a/socks.go b/socks.go new file mode 100644 index 0000000..fe7a7a2 --- /dev/null +++ b/socks.go @@ -0,0 +1,2079 @@ +package gost + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/url" + "strconv" + "sync" + "time" + + "github.com/go-gost/gosocks4" + "github.com/go-gost/gosocks5" + "github.com/go-log/log" + smux "github.com/xtaci/smux" +) + +const ( + // MethodTLS is an extended SOCKS5 method with tls encryption support. + MethodTLS uint8 = 0x80 + // MethodTLSAuth is an extended SOCKS5 method with tls encryption and authentication support. + MethodTLSAuth uint8 = 0x82 + // MethodMux is an extended SOCKS5 method for stream multiplexing. + MethodMux = 0x88 +) + +const ( + // CmdMuxBind is an extended SOCKS5 request CMD for + // multiplexing transport with the binding server. + CmdMuxBind uint8 = 0xF2 + // CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP. + CmdUDPTun uint8 = 0xF3 +) + +var ( + _ net.PacketConn = (*socks5UDPTunnelConn)(nil) +) + +type clientSelector struct { + methods []uint8 + User *url.Userinfo + TLSConfig *tls.Config +} + +func (selector *clientSelector) Methods() []uint8 { + if Debug { + log.Log("[socks5] methods:", selector.methods) + } + return selector.methods +} + +func (selector *clientSelector) AddMethod(methods ...uint8) { + selector.methods = append(selector.methods, methods...) +} + +func (selector *clientSelector) Select(methods ...uint8) (method uint8) { + return +} + +func (selector *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { + if Debug { + log.Log("[socks5] method selected:", method) + } + switch method { + case MethodTLS: + conn = tls.Client(conn, selector.TLSConfig) + + case gosocks5.MethodUserPass, MethodTLSAuth: + if method == MethodTLSAuth { + conn = tls.Client(conn, selector.TLSConfig) + } + + var username, password string + if selector.User != nil { + username = selector.User.Username() + password, _ = selector.User.Password() + } + + req := gosocks5.NewUserPassRequest(gosocks5.UserPassVer, username, password) + if err := req.Write(conn); err != nil { + log.Log("[socks5]", err) + return nil, err + } + if Debug { + log.Log("[socks5]", req) + } + resp, err := gosocks5.ReadUserPassResponse(conn) + if err != nil { + log.Log("[socks5]", err) + return nil, err + } + if Debug { + log.Log("[socks5]", resp) + } + if resp.Status != gosocks5.Succeeded { + return nil, gosocks5.ErrAuthFailure + } + case gosocks5.MethodNoAcceptable: + return nil, gosocks5.ErrBadMethod + } + + return conn, nil +} + +type serverSelector struct { + methods []uint8 + // Users []*url.Userinfo + Authenticator Authenticator + TLSConfig *tls.Config +} + +func (selector *serverSelector) Methods() []uint8 { + return selector.methods +} + +func (selector *serverSelector) AddMethod(methods ...uint8) { + selector.methods = append(selector.methods, methods...) +} + +func (selector *serverSelector) Select(methods ...uint8) (method uint8) { + if Debug { + log.Logf("[socks5] %d %d %v", gosocks5.Ver5, len(methods), methods) + } + method = gosocks5.MethodNoAuth + for _, m := range methods { + if m == MethodTLS { + method = m + break + } + } + + // when Authenticator is set, auth is mandatory + if selector.Authenticator != nil { + if method == gosocks5.MethodNoAuth { + method = gosocks5.MethodUserPass + } + if method == MethodTLS { + method = MethodTLSAuth + } + } + + return +} + +func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { + if Debug { + log.Logf("[socks5] %d %d", gosocks5.Ver5, method) + } + switch method { + case MethodTLS: + conn = tls.Server(conn, selector.TLSConfig) + + case gosocks5.MethodUserPass, MethodTLSAuth: + if method == MethodTLSAuth { + conn = tls.Server(conn, selector.TLSConfig) + } + + req, err := gosocks5.ReadUserPassRequest(conn) + if err != nil { + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return nil, err + } + if Debug { + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), req.String()) + } + + if selector.Authenticator != nil && !selector.Authenticator.Authenticate(req.Username, req.Password) { + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) + if err := resp.Write(conn); err != nil { + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return nil, err + } + if Debug { + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp) + } + log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr()) + return nil, gosocks5.ErrAuthFailure + } + + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) + if err := resp.Write(conn); err != nil { + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return nil, err + } + if Debug { + log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), resp) + } + case gosocks5.MethodNoAcceptable: + return nil, gosocks5.ErrBadMethod + } + + return conn, nil +} + +type socks5Connector struct { + User *url.Userinfo +} + +// SOCKS5Connector creates a connector for SOCKS5 proxy client. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5Connector(user *url.Userinfo) Connector { + return &socks5Connector{User: user} +} + +func (c *socks5Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks5Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + cnr := &socks5UDPTunConnector{User: c.User} + return cnr.ConnectContext(ctx, conn, network, address, options...) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + user := opts.User + if user == nil { + user = c.User + } + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user), + noTLSSocks5HandshakeOption(opts.NoTLS), + ) + if err != nil { + return nil, err + } + conn = cc + + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + req := gosocks5.NewRequest(gosocks5.CmdConnect, &gosocks5.Addr{ + Type: gosocks5.AddrDomain, + Host: host, + Port: uint16(p), + }) + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5]", req) + } + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5]", reply) + } + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("Service unavailable") + } + + return conn, nil +} + +type socks5BindConnector struct { + User *url.Userinfo +} + +// SOCKS5BindConnector creates a connector for SOCKS5 bind. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5BindConnector(user *url.Userinfo) Connector { + return &socks5BindConnector{User: user} +} + +func (c *socks5BindConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks5BindConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + user := opts.User + if user == nil { + user = c.User + } + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user), + noTLSSocks5HandshakeOption(opts.NoTLS), + ) + if err != nil { + return nil, err + } + conn = cc + + laddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + log.Log(err) + return nil, err + } + + req := gosocks5.NewRequest(gosocks5.CmdBind, &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + Host: laddr.IP.String(), + Port: uint16(laddr.Port), + }) + + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] bind\n", req) + } + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] bind\n", reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5] bind on %s failure", address) + return nil, fmt.Errorf("SOCKS5 bind on %s failure", address) + } + baddr, err := net.ResolveTCPAddr("tcp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] bind on %s OK", baddr) + + return &socks5BindConn{Conn: conn, laddr: baddr}, nil +} + +type socks5MuxBindConnector struct{} + +// Socks5MuxBindConnector creates a Connector for SOCKS5 multiplex bind client. +func Socks5MuxBindConnector() Connector { + return &socks5MuxBindConnector{} +} + +func (c *socks5MuxBindConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +// NOTE: the conn must be *muxBindClientConn. +func (c *socks5MuxBindConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + accepter, ok := conn.(Accepter) + if !ok { + return nil, errors.New("wrong connection type") + } + + return accepter.Accept() +} + +type socks5MuxBindTransporter struct { + bindAddr string + sessions map[string]*muxSession // server addr to session mapping + sessionMutex sync.Mutex +} + +// SOCKS5MuxBindTransporter creates a Transporter for SOCKS5 multiplex bind client. +func SOCKS5MuxBindTransporter(bindAddr string) Transporter { + return &socks5MuxBindTransporter{ + bindAddr: bindAddr, + sessions: make(map[string]*muxSession), + } +} + +func (tr *socks5MuxBindTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if session != nil && session.IsClosed() { + delete(tr.sessions, addr) + ok = false + } + if !ok { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &muxSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *socks5MuxBindTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.session == nil { + s, err := tr.initSession(conn, tr.bindAddr, opts) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + + return &muxBindClientConn{session: session}, nil +} + +func (tr *socks5MuxBindTransporter) initSession(conn net.Conn, addr string, opts *HandshakeOptions) (*muxSession, error) { + if opts == nil { + opts = &HandshakeOptions{} + } + + cc, err := socks5Handshake(conn, + userSocks5HandshakeOption(opts.User), + ) + if err != nil { + return nil, err + } + conn = cc + + bindAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + + req := gosocks5.NewRequest(CmdMuxBind, &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + Host: bindAddr.IP.String(), + Port: uint16(bindAddr.Port), + }) + + if err = req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] mbind\n", req) + } + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] mbind\n", reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5] mbind on %s failure", addr) + return nil, fmt.Errorf("SOCKS5 mbind on %s failure", addr) + } + baddr, err := net.ResolveTCPAddr("tcp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] mbind on %s OK", baddr) + + // Upgrade connection to multiplex stream. + session, err := smux.Server(conn, smux.DefaultConfig()) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} + +func (tr *socks5MuxBindTransporter) Multiplex() bool { + return true +} + +type socks5UDPConnector struct { + User *url.Userinfo +} + +// SOCKS5UDPConnector creates a connector for SOCKS5 UDP relay. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5UDPConnector(user *url.Userinfo) Connector { + return &socks5UDPConnector{User: user} +} + +func (c *socks5UDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "udp", address, options...) +} + +func (c *socks5UDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + user := opts.User + if user == nil { + user = c.User + } + cc, err := socks5Handshake(conn, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user), + noTLSSocks5HandshakeOption(opts.NoTLS), + ) + if err != nil { + return nil, err + } + conn = cc + + taddr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, err + } + + req := gosocks5.NewRequest(gosocks5.CmdUdp, &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + }) + + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] udp\n", req) + } + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] udp\n", reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5] udp relay failure") + return nil, fmt.Errorf("SOCKS5 udp relay failure") + } + baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] udp associate on %s OK", baddr) + + uc, err := net.DialUDP("udp", nil, baddr) + if err != nil { + return nil, err + } + // log.Logf("udp laddr:%s, raddr:%s", uc.LocalAddr(), uc.RemoteAddr()) + + return &socks5UDPConn{UDPConn: uc, taddr: taddr}, nil +} + +type socks5UDPTunConnector struct { + User *url.Userinfo +} + +// SOCKS5UDPTunConnector creates a connector for SOCKS5 UDP-over-TCP relay. +// It accepts an optional auth info for SOCKS5 Username/Password Authentication. +func SOCKS5UDPTunConnector(user *url.Userinfo) Connector { + return &socks5UDPTunConnector{User: user} +} + +func (c *socks5UDPTunConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "udp", address, options...) +} + +func (c *socks5UDPTunConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + user := opts.User + if user == nil { + user = c.User + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + taddr, _ := net.ResolveUDPAddr("udp", address) + return newSocks5UDPTunnelConn(conn, + nil, taddr, + selectorSocks5HandshakeOption(opts.Selector), + userSocks5HandshakeOption(user), + noTLSSocks5HandshakeOption(opts.NoTLS), + ) +} + +type socks4Connector struct{} + +// SOCKS4Connector creates a Connector for SOCKS4 proxy client. +func SOCKS4Connector() Connector { + return &socks4Connector{} +} + +func (c *socks4Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks4Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + taddr, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + return nil, err + } + if len(taddr.IP) == 0 { + taddr.IP = net.IPv4zero + } + + req := gosocks4.NewRequest(gosocks4.CmdConnect, + &gosocks4.Addr{ + Type: gosocks4.AddrIPv4, + Host: taddr.IP.String(), + Port: uint16(taddr.Port), + }, nil, + ) + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4] %s", req) + } + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4] %s", reply) + } + + if reply.Code != gosocks4.Granted { + return nil, fmt.Errorf("[socks4] %d", reply.Code) + } + + return conn, nil +} + +type socks4aConnector struct{} + +// SOCKS4AConnector creates a Connector for SOCKS4A proxy client. +func SOCKS4AConnector() Connector { + return &socks4aConnector{} +} + +func (c *socks4aConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *socks4aConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + + req := gosocks4.NewRequest(gosocks4.CmdConnect, + &gosocks4.Addr{Type: gosocks4.AddrDomain, Host: host, Port: uint16(p)}, nil) + if err := req.Write(conn); err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4a] %s", req) + } + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + return nil, err + } + + if Debug { + log.Logf("[socks4a] %s", reply) + } + + if reply.Code != gosocks4.Granted { + return nil, fmt.Errorf("[socks4a] %d", reply.Code) + } + + return conn, nil +} + +type socks5Handler struct { + selector *serverSelector + options *HandlerOptions +} + +// SOCKS5Handler creates a server Handler for SOCKS5 proxy server. +func SOCKS5Handler(opts ...HandlerOption) Handler { + h := &socks5Handler{} + h.Init(opts...) + + return h +} + +func (h *socks5Handler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + + tlsConfig := h.options.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + h.selector = &serverSelector{ // socks5 server selector + // Users: h.options.Users, + Authenticator: h.options.Authenticator, + TLSConfig: tlsConfig, + } + // methods that socks5 server supported + h.selector.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + MethodTLS, + MethodTLSAuth, + ) +} + +func (h *socks5Handler) Handle(conn net.Conn) { + defer conn.Close() + + conn = gosocks5.ServerConn(conn, h.selector) + req, err := gosocks5.ReadRequest(conn) + if err != nil { + log.Logf("[socks5] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + if Debug { + log.Logf("[socks5] %s -> %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), req) + } + switch req.Cmd { + case gosocks5.CmdConnect: + h.handleConnect(conn, req) + + case gosocks5.CmdBind: + h.handleBind(conn, req) + + case gosocks5.CmdUdp: + h.handleUDPRelay(conn, req) + + case CmdMuxBind: + h.handleMuxBind(conn, req) + + case CmdUDPTun: + h.handleUDPTunnel(conn, req) + + default: + log.Logf("[socks5] %s - %s : Unrecognized request: %d", + conn.RemoteAddr(), conn.LocalAddr(), req.Cmd) + } +} + +func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { + host := req.Addr.String() + + log.Logf("[socks5] %s -> %s -> %s", + conn.RemoteAddr(), h.options.Node.String(), host) + + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5] %s - %s : Unauthorized to tcp connect to %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + return + } + if h.options.Bypass.Contains(host) { + log.Logf("[socks5] %s - %s : Bypass %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + return + } + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var err error + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(host) + if err != nil { + log.Logf("[socks5] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + continue + } + + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + conn.RemoteAddr(), h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", host) + log.Log("[route]", buf.String()) + + cc, err = route.Dial(host, + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + log.Logf("[socks5] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + } + + if err != nil { + rep := gosocks5.NewReply(gosocks5.HostUnreachable, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + return + } + defer cc.Close() + + rep := gosocks5.NewReply(gosocks5.Succeeded, nil) + if err := rep.Write(conn); err != nil { + log.Logf("[socks5] %s <- %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + if Debug { + log.Logf("[socks5] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + log.Logf("[socks5] %s <-> %s", conn.RemoteAddr(), host) + transport(conn, cc) + log.Logf("[socks5] %s >-< %s", conn.RemoteAddr(), host) +} + +func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { + addr := req.Addr.String() + + log.Logf("[socks5-bind] %s -> %s -> %s", + conn.RemoteAddr(), h.options.Node.String(), addr) + + if h.options.Chain.IsEmpty() { + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-bind] %s - %s : Unauthorized to tcp bind to %s", + conn.RemoteAddr(), conn.LocalAddr(), addr) + return + } + h.bindOn(conn, addr) + return + } + + cc, err := h.options.Chain.Conn() + if err != nil { + log.Logf("[socks5-bind] %s <- %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks5-bind] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), reply) + } + return + } + + // forward request + // note: this type of request forwarding is defined when starting server, + // so we don't need to authenticate it, as it's as explicit as whitelisting + defer cc.Close() + req.Write(cc) + log.Logf("[socks5-bind] %s <-> %s", conn.RemoteAddr(), addr) + transport(conn, cc) + log.Logf("[socks5-bind] %s >-< %s", conn.RemoteAddr(), addr) +} + +func (h *socks5Handler) bindOn(conn net.Conn, addr string) { + bindAddr, _ := net.ResolveTCPAddr("tcp", addr) + ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error + if err != nil { + log.Logf("[socks5-bind] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + gosocks5.NewReply(gosocks5.Failure, nil).Write(conn) + return + } + + socksAddr := toSocksAddr(ln.Addr()) + // Issue: may not reachable when host has multi-interface + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5-bind] %s <- %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + ln.Close() + return + } + if Debug { + log.Logf("[socks5-bind] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), reply) + } + log.Logf("[socks5-bind] %s - %s BIND ON %s OK", + conn.RemoteAddr(), conn.LocalAddr(), socksAddr) + + var pconn net.Conn + accept := func() <-chan error { + errc := make(chan error, 1) + + go func() { + defer close(errc) + defer ln.Close() + + c, err := ln.AcceptTCP() + if err != nil { + errc <- err + return + } + pconn = c + }() + + return errc + } + + pc1, pc2 := net.Pipe() + pipe := func() <-chan error { + errc := make(chan error, 1) + + go func() { + defer close(errc) + defer pc1.Close() + + errc <- transport(conn, pc1) + }() + + return errc + } + + defer pc2.Close() + + for { + select { + case err := <-accept(): + if err != nil || pconn == nil { + log.Logf("[socks5-bind] %s <- %s : %v", conn.RemoteAddr(), addr, err) + return + } + defer pconn.Close() + + reply := gosocks5.NewReply(gosocks5.Succeeded, toSocksAddr(pconn.RemoteAddr())) + if err := reply.Write(pc2); err != nil { + log.Logf("[socks5-bind] %s <- %s : %v", conn.RemoteAddr(), addr, err) + } + if Debug { + log.Logf("[socks5-bind] %s <- %s\n%s", conn.RemoteAddr(), addr, reply) + } + log.Logf("[socks5-bind] %s <- %s PEER %s ACCEPTED", conn.RemoteAddr(), socksAddr, pconn.RemoteAddr()) + + log.Logf("[socks5-bind] %s <-> %s", conn.RemoteAddr(), pconn.RemoteAddr()) + if err = transport(pc2, pconn); err != nil { + log.Logf("[socks5-bind] %s - %s : %v", conn.RemoteAddr(), pconn.RemoteAddr(), err) + } + log.Logf("[socks5-bind] %s >-< %s", conn.RemoteAddr(), pconn.RemoteAddr()) + return + case err := <-pipe(): + if err != nil { + log.Logf("[socks5-bind] %s -> %s : %v", conn.RemoteAddr(), addr, err) + } + ln.Close() + return + } + } +} + +func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { + addr := req.Addr.String() + if !Can("udp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5-udp] Unauthorized to udp connect to %s", addr) + rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) + } + return + } + + relay, err := net.ListenUDP("udp", nil) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), reply) + } + return + } + defer relay.Close() + + socksAddr := toSocksAddr(relay.LocalAddr()) + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5-udp] %s <- %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), reply) + } + log.Logf("[socks5-udp] %s - %s BIND ON %s OK", conn.RemoteAddr(), conn.LocalAddr(), socksAddr) + + // serve as standard socks5 udp relay local <-> remote + if h.options.Chain.IsEmpty() { + peer, er := net.ListenUDP("udp", nil) + if er != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), er) + return + } + defer peer.Close() + + go h.transportUDP(relay, peer) + log.Logf("[socks5-udp] %s <-> %s : associated on %s", conn.RemoteAddr(), conn.LocalAddr(), socksAddr) + if err := h.discardClientData(conn); err != nil { + log.Logf("[socks5-udp] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + } + log.Logf("[socks5-udp] %s >-< %s : associated on %s", conn.RemoteAddr(), conn.LocalAddr(), socksAddr) + return + } + + // forward udp local <-> tunnel + cc, err := h.options.Chain.Conn() + // connection error + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + defer cc.Close() + + cc, err = socks5Handshake(cc, userSocks5HandshakeOption(h.options.Chain.LastNode().User)) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + + cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) + r := gosocks5.NewRequest(CmdUDPTun, nil) + if err := r.Write(cc); err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), cc.RemoteAddr(), err) + return + } + cc.SetWriteDeadline(time.Time{}) + if Debug { + log.Logf("[socks5-udp] %s -> %s\n%s", conn.RemoteAddr(), cc.RemoteAddr(), r) + } + cc.SetReadDeadline(time.Now().Add(ReadTimeout)) + reply, err = gosocks5.ReadReply(cc) + if err != nil { + log.Logf("[socks5-udp] %s -> %s : %s", conn.RemoteAddr(), cc.RemoteAddr(), err) + return + } + if Debug { + log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), cc.RemoteAddr(), reply) + } + + if reply.Rep != gosocks5.Succeeded { + log.Logf("[socks5-udp] %s <- %s : udp associate failed", conn.RemoteAddr(), cc.RemoteAddr()) + return + } + cc.SetReadDeadline(time.Time{}) + log.Logf("[socks5-udp] %s <-> %s [tun: %s]", conn.RemoteAddr(), socksAddr, reply.Addr) + + go h.tunnelClientUDP(relay, cc) + log.Logf("[socks5-udp] %s <-> %s", conn.RemoteAddr(), socksAddr) + if err := h.discardClientData(conn); err != nil { + log.Logf("[socks5-udp] %s - %s : %s", conn.RemoteAddr(), socksAddr, err) + } + log.Logf("[socks5-udp] %s >-< %s", conn.RemoteAddr(), socksAddr) +} + +func (h *socks5Handler) discardClientData(conn net.Conn) (err error) { + b := make([]byte, tinyBufferSize) + n := 0 + for { + n, err = conn.Read(b) // discard any data from tcp connection + if err != nil { + if err == io.EOF { // disconnect normally + err = nil + } + break // client disconnected + } + log.Logf("[socks5-udp] read %d UNEXPECTED TCP data from client", n) + } + return +} + +func (h *socks5Handler) transportUDP(relay, peer net.PacketConn) (err error) { + errc := make(chan error, 2) + + var clientAddr net.Addr + + go func() { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + for { + n, laddr, err := relay.ReadFrom(b) + if err != nil { + errc <- err + return + } + if clientAddr == nil { + clientAddr = laddr + } + dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n])) + if err != nil { + errc <- err + return + } + + raddr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + if err != nil { + continue // drop silently + } + if h.options.Bypass.Contains(raddr.String()) { + log.Log("[socks5-udp] [bypass] write to", raddr) + continue // bypass + } + if _, err := peer.WriteTo(dgram.Data, raddr); err != nil { + errc <- err + return + } + if Debug { + log.Logf("[socks5-udp] %s >>> %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) + } + } + }() + + go func() { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + for { + n, raddr, err := peer.ReadFrom(b) + if err != nil { + errc <- err + return + } + if clientAddr == nil { + continue + } + if h.options.Bypass.Contains(raddr.String()) { + log.Log("[socks5-udp] [bypass] read from", raddr) + continue // bypass + } + buf := bytes.Buffer{} + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), b[:n]) + dgram.Write(&buf) + if _, err := relay.WriteTo(buf.Bytes(), clientAddr); err != nil { + errc <- err + return + } + if Debug { + log.Logf("[socks5-udp] %s <<< %s length: %d", relay.LocalAddr(), raddr, len(dgram.Data)) + } + } + }() + + select { + case err = <-errc: + //log.Println("w exit", err) + } + + return +} + +func (h *socks5Handler) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error) { + errc := make(chan error, 2) + + var clientAddr *net.UDPAddr + + go func() { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + for { + n, addr, err := uc.ReadFromUDP(b) + if err != nil { + log.Logf("[udp-tun] %s <- %s : %s", cc.RemoteAddr(), addr, err) + errc <- err + return + } + + // glog.V(LDEBUG).Infof("read udp %d, % #x", n, b[:n]) + // pipe from relay to tunnel + dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(b[:n])) + if err != nil { + errc <- err + return + } + if clientAddr == nil { + clientAddr = addr + } + raddr := dgram.Header.Addr.String() + if h.options.Bypass.Contains(raddr) { + log.Log("[udp-tun] [bypass] write to", raddr) + continue // bypass + } + dgram.Header.Rsv = uint16(len(dgram.Data)) + if err := dgram.Write(cc); err != nil { + errc <- err + return + } + if Debug { + log.Logf("[udp-tun] %s >>> %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) + } + } + }() + + go func() { + for { + dgram, err := gosocks5.ReadUDPDatagram(cc) + if err != nil { + log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) + errc <- err + return + } + + // pipe from tunnel to relay + if clientAddr == nil { + continue + } + raddr := dgram.Header.Addr.String() + if h.options.Bypass.Contains(raddr) { + log.Log("[udp-tun] [bypass] read from", raddr) + continue // bypass + } + dgram.Header.Rsv = 0 + + buf := bytes.Buffer{} + dgram.Write(&buf) + if _, err := uc.WriteToUDP(buf.Bytes(), clientAddr); err != nil { + errc <- err + return + } + if Debug { + log.Logf("[udp-tun] %s <<< %s length: %d", uc.LocalAddr(), dgram.Header.Addr, len(dgram.Data)) + } + } + }() + + select { + case err = <-errc: + } + + return +} + +func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { + // serve tunnel udp, tunnel <-> remote, handle tunnel udp request + if h.options.Chain.IsEmpty() { + addr := req.Addr.String() + + if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks5] udp-tun Unauthorized to udp bind to %s", addr) + return + } + + bindAddr, _ := net.ResolveUDPAddr("udp", addr) + uc, err := net.ListenUDP("udp", bindAddr) + if err != nil { + log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + return + } + defer uc.Close() + + socksAddr := toSocksAddr(uc.LocalAddr()) + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5] udp-tun %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + if Debug { + log.Logf("[socks5] udp-tun %s <- %s\n%s", conn.RemoteAddr(), socksAddr, reply) + } + log.Logf("[socks5] udp-tun %s <-> %s", conn.RemoteAddr(), socksAddr) + h.tunnelServerUDP(conn, uc) + log.Logf("[socks5] udp-tun %s >-< %s", conn.RemoteAddr(), socksAddr) + return + } + + cc, err := h.options.Chain.Conn() + // connection error + if err != nil { + log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + log.Logf("[socks5] udp-tun %s -> %s\n%s", conn.RemoteAddr(), req.Addr, reply) + return + } + defer cc.Close() + + cc, err = socks5Handshake(cc, userSocks5HandshakeOption(h.options.Chain.LastNode().User)) + if err != nil { + log.Logf("[socks5] udp-tun %s -> %s : %s", conn.RemoteAddr(), req.Addr, err) + return + } + // tunnel <-> tunnel, direct forwarding + // note: this type of request forwarding is defined when starting server + // so we don't need to authenticate it, as it's as explicit as whitelisting + req.Write(cc) + + log.Logf("[socks5] udp-tun %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks5] udp-tun %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) +} + +func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err error) { + errc := make(chan error, 2) + + go func() { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + for { + n, addr, err := pc.ReadFrom(b) + if err != nil { + // log.Logf("[udp-tun] %s : %s", cc.RemoteAddr(), err) + errc <- err + return + } + if h.options.Bypass.Contains(addr.String()) { + log.Log("[socks5] udp-tun bypass read from", addr) + continue // bypass + } + + // pipe from peer to tunnel + dgram := gosocks5.NewUDPDatagram( + gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) + if err := dgram.Write(cc); err != nil { + log.Logf("[socks5] udp-tun %s <- %s : %s", cc.RemoteAddr(), dgram.Header.Addr, err) + errc <- err + return + } + if Debug { + log.Logf("[socks5] udp-tun %s <<< %s length: %d", cc.RemoteAddr(), dgram.Header.Addr, len(dgram.Data)) + } + } + }() + + go func() { + for { + dgram, err := gosocks5.ReadUDPDatagram(cc) + if err != nil { + // log.Logf("[udp-tun] %s -> 0 : %s", cc.RemoteAddr(), err) + errc <- err + return + } + + // pipe from tunnel to peer + addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + if err != nil { + continue // drop silently + } + if h.options.Bypass.Contains(addr.String()) { + log.Log("[socks5] udp-tun bypass write to", addr) + continue // bypass + } + if _, err := pc.WriteTo(dgram.Data, addr); err != nil { + log.Logf("[socks5] udp-tun %s -> %s : %s", cc.RemoteAddr(), addr, err) + errc <- err + return + } + if Debug { + log.Logf("[socks5] udp-tun %s >>> %s length: %d", cc.RemoteAddr(), addr, len(dgram.Data)) + } + } + }() + + select { + case err = <-errc: + } + + return +} + +func (h *socks5Handler) handleMuxBind(conn net.Conn, req *gosocks5.Request) { + if h.options.Chain.IsEmpty() { + addr := req.Addr.String() + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("Unauthorized to tcp mbind to %s", addr) + return + } + h.muxBindOn(conn, addr) + return + } + + cc, err := h.options.Chain.Conn() + if err != nil { + log.Logf("[socks5] mbind %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks5] mbind %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) + } + return + } + + // forward request + // note: this type of request forwarding is defined when starting server, + // so we don't need to authenticate it, as it's as explicit as whitelisting. + defer cc.Close() + req.Write(cc) + log.Logf("[socks5] mbind %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks5] mbind %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) +} + +func (h *socks5Handler) muxBindOn(conn net.Conn, addr string) { + bindAddr, _ := net.ResolveTCPAddr("tcp", addr) + ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error + if err != nil { + log.Logf("[socks5] mbind %s -> %s : %s", conn.RemoteAddr(), addr, err) + gosocks5.NewReply(gosocks5.Failure, nil).Write(conn) + return + } + defer ln.Close() + + socksAddr := toSocksAddr(ln.Addr()) + // Issue: may not reachable when host has multi-interface. + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + log.Logf("[socks5] mbind %s <- %s : %s", conn.RemoteAddr(), addr, err) + return + } + if Debug { + log.Logf("[socks5] mbind %s <- %s\n%s", conn.RemoteAddr(), addr, reply) + } + log.Logf("[socks5] mbind %s - %s BIND ON %s OK", conn.RemoteAddr(), addr, socksAddr) + + // Upgrade connection to multiplex stream. + s, err := smux.Client(conn, smux.DefaultConfig()) + if err != nil { + log.Logf("[socks5] mbind %s - %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + + log.Logf("[socks5] mbind %s <-> %s", conn.RemoteAddr(), socksAddr) + defer log.Logf("[socks5] mbind %s >-< %s", conn.RemoteAddr(), socksAddr) + + session := &muxSession{ + conn: conn, + session: s, + } + defer session.Close() + + go func() { + for { + conn, err := session.Accept() + if err != nil { + log.Logf("[socks5] mbind accept : %v", err) + ln.Close() + return + } + conn.Close() // we do not handle incoming connection. + } + }() + + for { + cc, err := ln.Accept() + if err != nil { + log.Logf("[socks5] mbind %s <- %s : %v", conn.RemoteAddr(), socksAddr, err) + return + } + log.Logf("[socks5] mbind %s <- %s : ACCEPT peer %s", + conn.RemoteAddr(), socksAddr, cc.RemoteAddr()) + + go func(c net.Conn) { + defer c.Close() + + sc, err := session.GetConn() + if err != nil { + log.Logf("[socks5] mbind %s <- %s : %s", conn.RemoteAddr(), socksAddr, err) + return + } + defer sc.Close() + + transport(sc, c) + }(cc) + } +} + +// TODO: support ipv6 and domain +func toSocksAddr(addr net.Addr) *gosocks5.Addr { + host := "0.0.0.0" + port := 0 + if addr != nil { + h, p, _ := net.SplitHostPort(addr.String()) + host = h + port, _ = strconv.Atoi(p) + } + return &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + Host: host, + Port: uint16(port), + } +} + +type socks4Handler struct { + options *HandlerOptions +} + +// SOCKS4Handler creates a server Handler for SOCKS4(A) proxy server. +func SOCKS4Handler(opts ...HandlerOption) Handler { + h := &socks4Handler{} + h.Init(opts...) + + return h +} + +func (h *socks4Handler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } +} + +func (h *socks4Handler) Handle(conn net.Conn) { + defer conn.Close() + + req, err := gosocks4.ReadRequest(conn) + if err != nil { + log.Logf("[socks4] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + if Debug { + log.Logf("[socks4] %s -> %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), req) + } + + switch req.Cmd { + case gosocks4.CmdConnect: + h.handleConnect(conn, req) + + case gosocks4.CmdBind: + log.Logf("[socks4-bind] %s - %s", conn.RemoteAddr(), req.Addr) + h.handleBind(conn, req) + + default: + log.Logf("[socks4] %s - %s : Unrecognized request: %d", + conn.RemoteAddr(), conn.LocalAddr(), req.Cmd) + } +} + +func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { + addr := req.Addr.String() + + log.Logf("[socks4] %s -> %s -> %s", + conn.RemoteAddr(), h.options.Node.String(), addr) + + if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[socks4] %s - %s : Unauthorized to tcp connect to %s", + conn.RemoteAddr(), conn.LocalAddr(), addr) + rep := gosocks4.NewReply(gosocks4.Rejected, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks4] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + return + } + if h.options.Bypass.Contains(addr) { + log.Log("[socks4] %s - %s : Bypass %s", + conn.RemoteAddr(), conn.LocalAddr(), addr) + rep := gosocks4.NewReply(gosocks4.Rejected, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks4] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + return + } + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var err error + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(addr) + if err != nil { + log.Logf("[socks4] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + continue + } + + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + conn.RemoteAddr(), h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", addr) + log.Log("[route]", buf.String()) + + cc, err = route.Dial(addr, + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + log.Logf("[socks4] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + } + + if err != nil { + rep := gosocks4.NewReply(gosocks4.Failed, nil) + rep.Write(conn) + if Debug { + log.Logf("[socks4] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + return + } + defer cc.Close() + + rep := gosocks4.NewReply(gosocks4.Granted, nil) + if err := rep.Write(conn); err != nil { + log.Logf("[socks4] %s <- %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + if Debug { + log.Logf("[socks4] %s <- %s\n%s", + conn.RemoteAddr(), conn.LocalAddr(), rep) + } + + log.Logf("[socks4] %s <-> %s", conn.RemoteAddr(), addr) + transport(conn, cc) + log.Logf("[socks4] %s >-< %s", conn.RemoteAddr(), addr) +} + +func (h *socks4Handler) handleBind(conn net.Conn, req *gosocks4.Request) { + // TODO: serve socks4 bind + if h.options.Chain.IsEmpty() { + reply := gosocks4.NewReply(gosocks4.Rejected, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks4-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) + } + return + } + + cc, err := h.options.Chain.Conn() + // connection error + if err != nil && err != ErrEmptyChain { + log.Logf("[socks4-bind] %s <- %s : %s", conn.RemoteAddr(), req.Addr, err) + reply := gosocks4.NewReply(gosocks4.Failed, nil) + reply.Write(conn) + if Debug { + log.Logf("[socks4-bind] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, reply) + } + return + } + + defer cc.Close() + // forward request + req.Write(cc) + + log.Logf("[socks4-bind] %s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + transport(conn, cc) + log.Logf("[socks4-bind] %s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) +} + +type socks5HandshakeOptions struct { + selector gosocks5.Selector + user *url.Userinfo + tlsConfig *tls.Config + noTLS bool +} + +type socks5HandshakeOption func(opts *socks5HandshakeOptions) + +func selectorSocks5HandshakeOption(selector gosocks5.Selector) socks5HandshakeOption { + return func(opts *socks5HandshakeOptions) { + opts.selector = selector + } +} + +func userSocks5HandshakeOption(user *url.Userinfo) socks5HandshakeOption { + return func(opts *socks5HandshakeOptions) { + opts.user = user + } +} + +func noTLSSocks5HandshakeOption(noTLS bool) socks5HandshakeOption { + return func(opts *socks5HandshakeOptions) { + opts.noTLS = noTLS + } +} + +func socks5Handshake(conn net.Conn, opts ...socks5HandshakeOption) (net.Conn, error) { + options := socks5HandshakeOptions{} + for _, opt := range opts { + opt(&options) + } + selector := options.selector + if selector == nil { + cs := &clientSelector{ + TLSConfig: &tls.Config{InsecureSkipVerify: true}, + User: options.user, + } + cs.AddMethod( + gosocks5.MethodNoAuth, + gosocks5.MethodUserPass, + ) + if !options.noTLS { + cs.AddMethod(MethodTLS) + } + selector = cs + } + + cc := gosocks5.ClientConn(conn, selector) + if err := cc.Handleshake(); err != nil { + return nil, err + } + return cc, nil +} + +func getSocks5UDPTunnel(chain *Chain, addr net.Addr) (net.Conn, error) { + c, err := chain.Conn() + if err != nil { + return nil, err + } + + node := chain.LastNode() + conn, err := newSocks5UDPTunnelConn(c, + addr, nil, + userSocks5HandshakeOption(node.User), + noTLSSocks5HandshakeOption(node.GetBool("notls")), + ) + if err != nil { + c.Close() + } + return conn, err +} + +type socks5UDPTunnelConn struct { + net.Conn + taddr net.Addr +} + +func newSocks5UDPTunnelConn(conn net.Conn, raddr, taddr net.Addr, opts ...socks5HandshakeOption) (net.Conn, error) { + cc, err := socks5Handshake(conn, opts...) + if err != nil { + return nil, err + } + + req := gosocks5.NewRequest(CmdUDPTun, toSocksAddr(raddr)) + if err := req.Write(cc); err != nil { + return nil, err + } + if Debug { + log.Log("[socks5] udp-tun", req) + } + + reply, err := gosocks5.ReadReply(cc) + if err != nil { + return nil, err + } + + if Debug { + log.Log("[socks5] udp-tun", reply) + } + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("socks5 UDP tunnel failure") + } + + baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) + if err != nil { + return nil, err + } + log.Logf("[socks5] udp-tun associate on %s OK", baddr) + + return &socks5UDPTunnelConn{ + Conn: cc, + taddr: taddr, + }, nil +} + +func (c *socks5UDPTunnelConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *socks5UDPTunnelConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + dgram, err := gosocks5.ReadUDPDatagram(c.Conn) + if err != nil { + return + } + n = copy(b, dgram.Data) + addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + return +} + +func (c *socks5UDPTunnelConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +func (c *socks5UDPTunnelConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, toSocksAddr(addr)), b) + if err = dgram.Write(c.Conn); err != nil { + return + } + return len(b), nil +} + +// socks5BindConn is a connection for SOCKS5 bind client. +type socks5BindConn struct { + raddr net.Addr + laddr net.Addr + net.Conn + handshaked bool + handshakeMux sync.Mutex +} + +// Handshake waits for a peer to connect to the bind port. +func (c *socks5BindConn) Handshake() (err error) { + c.handshakeMux.Lock() + defer c.handshakeMux.Unlock() + + if c.handshaked { + return nil + } + + c.handshaked = true + + rep, err := gosocks5.ReadReply(c.Conn) + if err != nil { + return fmt.Errorf("bind: read reply %v", err) + } + if rep.Rep != gosocks5.Succeeded { + return fmt.Errorf("bind: peer connect failure") + } + c.raddr, err = net.ResolveTCPAddr("tcp", rep.Addr.String()) + return +} + +func (c *socks5BindConn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + return c.Conn.Read(b) +} + +func (c *socks5BindConn) Write(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + return c.Conn.Write(b) +} + +func (c *socks5BindConn) LocalAddr() net.Addr { + return c.laddr +} + +func (c *socks5BindConn) RemoteAddr() net.Addr { + return c.raddr +} + +type socks5UDPConn struct { + *net.UDPConn + taddr net.Addr +} + +func (c *socks5UDPConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *socks5UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + data := mPool.Get().([]byte) + defer mPool.Put(data) + + n, err = c.UDPConn.Read(data) + if err != nil { + return + } + dg, err := gosocks5.ReadUDPDatagram(bytes.NewReader(data[:n])) + if err != nil { + return + } + + n = copy(b, dg.Data) + addr, err = net.ResolveUDPAddr("udp", dg.Header.Addr.String()) + + return +} + +func (c *socks5UDPConn) Write(b []byte) (int, error) { + return c.WriteTo(b, c.taddr) +} + +func (c *socks5UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + adr, err := gosocks5.NewAddr(addr.String()) + if err != nil { + return 0, err + } + h := gosocks5.NewUDPHeader(0, 0, adr) + dg := gosocks5.NewUDPDatagram(h, b) + if err = dg.Write(c.UDPConn); err != nil { + return 0, err + } + return len(b), nil +} + +// a dummy client conn for multiplex bind used by SOCKS5 multiplex bind client connector +type muxBindClientConn struct { + nopConn + session *muxSession +} + +func (c *muxBindClientConn) Accept() (net.Conn, error) { + return c.session.Accept() +} diff --git a/socks_test.go b/socks_test.go new file mode 100644 index 0000000..c88d94f --- /dev/null +++ b/socks_test.go @@ -0,0 +1,790 @@ +package gost + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +var socks5ProxyTests = []struct { + cliUser *url.Userinfo + srvUsers []*url.Userinfo + pass bool +}{ + {nil, nil, true}, + {nil, []*url.Userinfo{url.User("admin")}, false}, + {nil, []*url.Userinfo{url.UserPassword("", "123456")}, false}, + {url.User("admin"), []*url.Userinfo{url.User("test")}, false}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, false}, + {url.User("admin"), []*url.Userinfo{url.User("admin")}, true}, + {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, true}, + {url.UserPassword("admin", "123456"), nil, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, false}, + {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, true}, + {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, true}, +} + +func socks5ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(serverInfo...)), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5Proxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5ProxyRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func BenchmarkSOCKS5Proxy(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS5Connector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSOCKS5ProxyParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS5Connector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks4ProxyRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4Proxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4ProxyRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func BenchmarkSOCKS4Proxy(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSOCKS4ProxyParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks4aProxyRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AProxy(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aProxyRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func BenchmarkSOCKS4AProxy(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSOCKS4AProxyParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: TCPTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5BindRoundtrip(t *testing.T, targetURL string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + client := &Client{ + Connector: SOCKS5BindConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + return + } + defer conn.Close() + + conn, err = client.Connect(conn, "") + if err != nil { + return + } + + cc, err := net.Dial("tcp", conn.LocalAddr().String()) + if err != nil { + return + } + defer cc.Close() + + if err = conn.(*socks5BindConn).Handshake(); err != nil { + return + } + + u, err := url.Parse(targetURL) + if err != nil { + return + } + hc, err := net.Dial("tcp", u.Host) + if err != nil { + return + } + go transport(hc, conn) + + return httpRoundtrip(cc, targetURL, data) +} + +func TestSOCKS5Bind(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := socks5BindRoundtrip(t, httpSrv.URL, sendData); err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks5MuxBindRoundtrip(t *testing.T, targetURL string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + l, err := net.Listen("tcp", "") + if err != nil { + return err + } + bindAddr := l.Addr().String() + l.Close() + + client := &Client{ + Connector: Socks5MuxBindConnector(), + Transporter: SOCKS5MuxBindTransporter(bindAddr), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return muxBindRoundtrip(client, server, bindAddr, targetURL, data) +} + +func muxBindRoundtrip(client *Client, server *Server, bindAddr, targetURL string, data []byte) (err error) { + cn, err := client.Dial(server.Addr().String()) + if err != nil { + return err + } + + conn, err := client.Handshake(cn, + AddrHandshakeOption(server.Addr().String()), + UserHandshakeOption(url.UserPassword("admin", "123456")), + ) + if err != nil { + cn.Close() + return err + } + defer conn.Close() + + cc, err := net.Dial("tcp", bindAddr) + if err != nil { + return + } + defer cc.Close() + + conn, err = client.Connect(conn, "") + if err != nil { + return + } + + u, err := url.Parse(targetURL) + if err != nil { + return + } + hc, err := net.Dial("tcp", u.Host) + if err != nil { + return + } + defer hc.Close() + + go transport(hc, conn) + + return httpRoundtrip(cc, targetURL, data) +} + +func TestSOCKS5MuxBind(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := socks5MuxBindRoundtrip(t, httpSrv.URL, sendData); err != nil { + t.Errorf("got error: %v", err) + } +} + +func BenchmarkSOCKS5MuxBind(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + l, err := net.Listen("tcp", "") + if err != nil { + b.Error(err) + } + bindAddr := l.Addr().String() + l.Close() + + client := &Client{ + Connector: Socks5MuxBindConnector(), + Transporter: SOCKS5MuxBindTransporter(bindAddr), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := muxBindRoundtrip(client, server, bindAddr, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func socks5UDPRoundtrip(t *testing.T, host string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + client := &Client{ + Connector: SOCKS5UDPConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + return udpRoundtrip(t, client, server, host, data) +} + +func TestSOCKS5UDP(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := socks5UDPRoundtrip(t, udpSrv.Addr(), sendData); err != nil { + t.Errorf("got error: %v", err) + } +} + +// TODO: fix a probability of timeout. +func BenchmarkSOCKS5UDP(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS5UDPConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSOCKS5UDPSingleConn(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS5UDPConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + b.Error(err) + } + defer conn.Close() + + conn, err = client.Connect(conn, udpSrv.Addr()) + if err != nil { + b.Error(err) + } + + roundtrip := func(conn net.Conn, data []byte) error { + conn.SetDeadline(time.Now().Add(1 * time.Second)) + defer conn.SetDeadline(time.Time{}) + + if _, err = conn.Write(data); err != nil { + return err + } + + recv := make([]byte, len(data)) + if _, err = conn.Read(recv); err != nil { + return err + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + return nil + } + + for i := 0; i < b.N; i++ { + if err := roundtrip(conn, sendData); err != nil { + b.Error(err) + } + } +} + +func socks5UDPTunRoundtrip(t *testing.T, host string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + client := &Client{ + Connector: SOCKS5UDPTunConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + return udpRoundtrip(t, client, server, host, data) +} + +func TestSOCKS5UDPTun(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + if err := socks5UDPTunRoundtrip(t, udpSrv.Addr(), sendData); err != nil { + t.Errorf("got error: %v", err) + } +} + +func BenchmarkSOCKS5UDPTun(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS5UDPTunConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := udpRoundtrip(b, client, server, udpSrv.Addr(), sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSOCKS5UDPTunSingleConn(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SOCKS5UDPTunConnector(url.UserPassword("admin", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: SOCKS5Handler(UsersHandlerOption(url.UserPassword("admin", "123456"))), + Listener: ln, + } + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + b.Error(err) + } + defer conn.Close() + + conn, err = client.Connect(conn, udpSrv.Addr()) + if err != nil { + b.Error(err) + } + + roundtrip := func(conn net.Conn, data []byte) error { + conn.SetDeadline(time.Now().Add(1 * time.Second)) + defer conn.SetDeadline(time.Time{}) + + if _, err = conn.Write(data); err != nil { + return err + } + + recv := make([]byte, len(data)) + if _, err = conn.Read(recv); err != nil { + return err + } + + if !bytes.Equal(data, recv) { + return fmt.Errorf("data not equal") + } + return nil + } + + for i := 0; i < b.N; i++ { + if err := roundtrip(conn, sendData); err != nil { + b.Error(err) + } + } +} diff --git a/ss.go b/ss.go new file mode 100644 index 0000000..052cf19 --- /dev/null +++ b/ss.go @@ -0,0 +1,647 @@ +package gost + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "io" + "net" + "net/url" + "time" + + "github.com/go-gost/gosocks5" + "github.com/go-log/log" + "github.com/shadowsocks/go-shadowsocks2/core" + ss "github.com/shadowsocks/shadowsocks-go/shadowsocks" +) + +const ( + maxSocksAddrLen = 259 +) + +var ( + _ net.Conn = (*shadowConn)(nil) + _ net.PacketConn = (*shadowUDPPacketConn)(nil) +) + +type shadowConnector struct { + cipher core.Cipher +} + +// ShadowConnector creates a Connector for shadowsocks proxy client. +// It accepts an optional cipher info for shadowsocks data encryption/decryption. +func ShadowConnector(info *url.Userinfo) Connector { + return &shadowConnector{ + cipher: initShadowCipher(info), + } +} + +func (c *shadowConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *shadowConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + socksAddr, err := gosocks5.NewAddr(address) + if err != nil { + return nil, err + } + rawaddr := sPool.Get().([]byte) + defer sPool.Put(rawaddr) + + n, err := socksAddr.Encode(rawaddr) + if err != nil { + return nil, err + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + if c.cipher != nil { + conn = c.cipher.StreamConn(conn) + } + + sc := &shadowConn{ + Conn: conn, + } + + // write the addr at once. + if opts.NoDelay { + if _, err := sc.Write(rawaddr[:n]); err != nil { + return nil, err + } + } else { + sc.wbuf.Write(rawaddr[:n]) // cache the header + } + + return sc, nil +} + +type shadowHandler struct { + cipher core.Cipher + options *HandlerOptions +} + +// ShadowHandler creates a server Handler for shadowsocks proxy server. +func ShadowHandler(opts ...HandlerOption) Handler { + h := &shadowHandler{} + h.Init(opts...) + + return h +} + +func (h *shadowHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + if len(h.options.Users) > 0 { + h.cipher = initShadowCipher(h.options.Users[0]) + } +} + +func (h *shadowHandler) Handle(conn net.Conn) { + defer conn.Close() + + if h.cipher != nil { + conn = &shadowConn{ + Conn: h.cipher.StreamConn(conn), + } + } + + conn.SetReadDeadline(time.Now().Add(ReadTimeout)) + + addr, err := readSocksAddr(conn) + if err != nil { + log.Logf("[ss] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + + conn.SetReadDeadline(time.Time{}) + + host := addr.String() + log.Logf("[ss] %s -> %s", + conn.RemoteAddr(), host) + + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ss] %s - %s : Unauthorized to tcp connect to %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + return + } + + if h.options.Bypass.Contains(host) { + log.Logf("[ss] %s - %s : Bypass %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + return + } + + retries := 1 + if h.options.Chain != nil && h.options.Chain.Retries > 0 { + retries = h.options.Chain.Retries + } + if h.options.Retries > 0 { + retries = h.options.Retries + } + + var cc net.Conn + var route *Chain + for i := 0; i < retries; i++ { + route, err = h.options.Chain.selectRouteFor(host) + if err != nil { + log.Logf("[ss] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + continue + } + + buf := bytes.Buffer{} + fmt.Fprintf(&buf, "%s -> %s -> ", + conn.RemoteAddr(), h.options.Node.String()) + for _, nd := range route.route { + fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) + } + fmt.Fprintf(&buf, "%s", host) + log.Log("[route]", buf.String()) + + cc, err = route.Dial(host, + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err == nil { + break + } + log.Logf("[ss] %s -> %s : %s", + conn.RemoteAddr(), conn.LocalAddr(), err) + } + + if err != nil { + return + } + defer cc.Close() + + log.Logf("[ss] %s <-> %s", conn.RemoteAddr(), host) + transport(conn, cc) + log.Logf("[ss] %s >-< %s", conn.RemoteAddr(), host) +} + +type shadowUDPConnector struct { + cipher core.Cipher +} + +// ShadowUDPConnector creates a Connector for shadowsocks UDP client. +// It accepts an optional cipher info for shadowsocks data encryption/decryption. +func ShadowUDPConnector(info *url.Userinfo) Connector { + return &shadowUDPConnector{ + cipher: initShadowCipher(info), + } +} + +func (c *shadowUDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "udp", address, options...) +} + +func (c *shadowUDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "tcp", "tcp4", "tcp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + taddr, _ := net.ResolveUDPAddr(network, address) + if taddr == nil { + taddr = &net.UDPAddr{} + } + + pc, ok := conn.(net.PacketConn) + if ok { + if c.cipher != nil { + pc = c.cipher.PacketConn(pc) + } + + return &shadowUDPPacketConn{ + PacketConn: pc, + raddr: conn.RemoteAddr(), + taddr: taddr, + }, nil + } + + if c.cipher != nil { + conn = &shadowConn{ + Conn: c.cipher.StreamConn(conn), + } + } + + return &socks5UDPTunnelConn{ + Conn: conn, + taddr: taddr, + }, nil +} + +type shadowUDPHandler struct { + cipher core.Cipher + options *HandlerOptions +} + +// ShadowUDPHandler creates a server Handler for shadowsocks UDP relay server. +func ShadowUDPHandler(opts ...HandlerOption) Handler { + h := &shadowUDPHandler{} + h.Init(opts...) + + return h +} + +func (h *shadowUDPHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } + if len(h.options.Users) > 0 { + h.cipher = initShadowCipher(h.options.Users[0]) + } +} + +func (h *shadowUDPHandler) Handle(conn net.Conn) { + defer conn.Close() + + var cc net.PacketConn + c, err := h.options.Chain.DialContext(context.Background(), "udp", "") + if err != nil { + log.Logf("[ssu] %s: %s", conn.LocalAddr(), err) + return + } + var ok bool + cc, ok = c.(net.PacketConn) + if !ok { + log.Logf("[ssu] %s: not a packet connection", conn.LocalAddr()) + return + } + + defer cc.Close() + + pc, ok := conn.(net.PacketConn) + if ok { + if h.cipher != nil { + pc = h.cipher.PacketConn(pc) + } + log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + h.transportPacket(pc, cc) + log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) + return + } + + if h.cipher != nil { + conn = &shadowConn{ + Conn: h.cipher.StreamConn(conn), + } + } + + log.Logf("[ssu] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + h.transportUDP(conn, cc) + log.Logf("[ssu] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) +} + +func (h *shadowUDPHandler) transportPacket(conn, cc net.PacketConn) (err error) { + errc := make(chan error, 1) + var clientAddr net.Addr + + go func() { + for { + err := func() error { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, addr, err := conn.ReadFrom(b) + if err != nil { + return err + } + if clientAddr == nil { + clientAddr = addr + } + + r := bytes.NewBuffer(b[:n]) + saddr, err := readSocksAddr(r) + if err != nil { + return err + } + taddr, err := net.ResolveUDPAddr("udp", saddr.String()) + if err != nil { + return err + } + if Debug { + log.Logf("[ssu] %s >>> %s length: %d", addr, taddr, r.Len()) + } + _, err = cc.WriteTo(r.Bytes(), taddr) + return err + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, addr, err := cc.ReadFrom(b) + if err != nil { + return err + } + if clientAddr == nil { + return nil + } + + if Debug { + log.Logf("[ssu] %s <<< %s length: %d", clientAddr, addr, n) + } + + dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(addr)), b[:n]) + buf := bytes.Buffer{} + if err = dgram.Write(&buf); err != nil { + return err + } + _, err = conn.WriteTo(buf.Bytes()[3:], clientAddr) + return err + }() + + if err != nil { + errc <- err + return + } + } + }() + + select { + case err = <-errc: + } + + return +} + +func (h *shadowUDPHandler) transportUDP(conn net.Conn, cc net.PacketConn) error { + errc := make(chan error, 1) + + go func() { + for { + er := func() (err error) { + dgram, err := gosocks5.ReadUDPDatagram(conn) + if err != nil { + // log.Logf("[ssu] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + if Debug { + log.Logf("[ssu] %s >>> %s length: %d", + conn.RemoteAddr(), dgram.Header.Addr.String(), len(dgram.Data)) + } + addr, err := net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + if err != nil { + return + } + if h.options.Bypass.Contains(addr.String()) { + log.Log("[ssu] bypass", addr) + return // bypass + } + _, err = cc.WriteTo(dgram.Data, addr) + return + }() + + if er != nil { + errc <- er + return + } + } + }() + + go func() { + for { + er := func() (err error) { + b := mPool.Get().([]byte) + defer mPool.Put(b) + + n, addr, err := cc.ReadFrom(b) + if err != nil { + return + } + if Debug { + log.Logf("[ssu] %s <<< %s length: %d", conn.RemoteAddr(), addr, n) + } + if h.options.Bypass.Contains(addr.String()) { + log.Log("[ssu] bypass", addr) + return // bypass + } + dgram := gosocks5.NewUDPDatagram( + gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) + buf := bytes.Buffer{} + dgram.Write(&buf) + _, err = conn.Write(buf.Bytes()) + return + }() + + if er != nil { + errc <- er + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} + +// Due to in/out byte length is inconsistent of the shadowsocks.Conn.Write, +// we wrap around it to make io.Copy happy. +type shadowConn struct { + net.Conn + wbuf bytes.Buffer +} + +func (c *shadowConn) Write(b []byte) (n int, err error) { + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.Conn.Write(c.wbuf.Bytes()) + c.wbuf.Reset() + return + } + _, err = c.Conn.Write(b) + return +} + +type shadowUDPPacketConn struct { + net.PacketConn + raddr net.Addr + taddr net.Addr +} + +func (c *shadowUDPPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + buf := mPool.Get().([]byte) + defer mPool.Put(buf) + + buf[0] = 0 + buf[1] = 0 + buf[2] = 0 + + n, _, err = c.PacketConn.ReadFrom(buf[3:]) + if err != nil { + return + } + + dgram, err := gosocks5.ReadUDPDatagram(bytes.NewReader(buf[:n+3])) + if err != nil { + return + } + n = copy(b, dgram.Data) + addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String()) + + return + +} + +func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *shadowUDPPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + sa, err := gosocks5.NewAddr(addr.String()) + if err != nil { + return + } + var rawaddr [maxSocksAddrLen]byte + nn, err := sa.Encode(rawaddr[:]) + if err != nil { + return + } + + buf := mPool.Get().([]byte) + defer mPool.Put(buf) + + copy(buf, rawaddr[:nn]) + n = copy(buf[nn:], b) + _, err = c.PacketConn.WriteTo(buf[:n+nn], c.raddr) + + return +} + +func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +func (c *shadowUDPPacketConn) RemoteAddr() net.Addr { + return c.raddr +} + +type shadowCipher struct { + cipher *ss.Cipher +} + +func (c *shadowCipher) StreamConn(conn net.Conn) net.Conn { + return ss.NewConn(conn, c.cipher.Copy()) +} + +func (c *shadowCipher) PacketConn(conn net.PacketConn) net.PacketConn { + return ss.NewSecurePacketConn(conn, c.cipher.Copy()) +} + +func initShadowCipher(info *url.Userinfo) (cipher core.Cipher) { + var method, password string + if info != nil { + method = info.Username() + password, _ = info.Password() + } + + if method == "" || password == "" { + return + } + + cp, _ := ss.NewCipher(method, password) + if cp != nil { + cipher = &shadowCipher{cipher: cp} + } + if cipher == nil { + var err error + cipher, err = core.PickCipher(method, nil, password) + if err != nil { + log.Logf("[ss] %s", err) + return + } + } + return +} + +func readSocksAddr(r io.Reader) (*gosocks5.Addr, error) { + addr := &gosocks5.Addr{} + b := sPool.Get().([]byte) + defer sPool.Put(b) + + _, err := io.ReadFull(r, b[:1]) + if err != nil { + return nil, err + } + addr.Type = b[0] + + switch addr.Type { + case gosocks5.AddrIPv4: + _, err = io.ReadFull(r, b[:net.IPv4len]) + addr.Host = net.IP(b[0:net.IPv4len]).String() + case gosocks5.AddrIPv6: + _, err = io.ReadFull(r, b[:net.IPv6len]) + addr.Host = net.IP(b[0:net.IPv6len]).String() + case gosocks5.AddrDomain: + if _, err = io.ReadFull(r, b[:1]); err != nil { + return nil, err + } + addrlen := int(b[0]) + _, err = io.ReadFull(r, b[:addrlen]) + addr.Host = string(b[:addrlen]) + default: + return nil, gosocks5.ErrBadAddrType + } + if err != nil { + return nil, err + } + + _, err = io.ReadFull(r, b[:2]) + addr.Port = binary.BigEndian.Uint16(b[:2]) + return addr, err +} diff --git a/ss_test.go b/ss_test.go new file mode 100644 index 0000000..62484d0 --- /dev/null +++ b/ss_test.go @@ -0,0 +1,567 @@ +package gost + +import ( + "bytes" + "crypto/rand" + "fmt" + "net/http/httptest" + "net/url" + "testing" + "time" +) + +func init() { + // ss.Debug = true +} + +var ssTests = []struct { + clientCipher *url.Userinfo + serverCipher *url.Userinfo + pass bool +}{ + {nil, nil, true}, + {&url.Userinfo{}, &url.Userinfo{}, true}, + {url.User("abc"), url.User("abc"), true}, + {url.UserPassword("abc", "def"), url.UserPassword("abc", "def"), true}, + + {url.User("aes-128-cfb"), url.User("aes-128-cfb"), true}, + {url.User("aes-128-cfb"), url.UserPassword("aes-128-cfb", "123456"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.User("aes-128-cfb"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.UserPassword("aes-128-cfb", "abc"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.UserPassword("aes-128-cfb", "123456"), true}, + + {url.User("aes-192-cfb"), url.User("aes-192-cfb"), true}, + {url.User("aes-192-cfb"), url.UserPassword("aes-192-cfb", "123456"), false}, + {url.UserPassword("aes-192-cfb", "123456"), url.User("aes-192-cfb"), false}, + {url.UserPassword("aes-192-cfb", "123456"), url.UserPassword("aes-192-cfb", "abc"), false}, + {url.UserPassword("aes-192-cfb", "123456"), url.UserPassword("aes-192-cfb", "123456"), true}, + + {url.User("aes-256-cfb"), url.User("aes-256-cfb"), true}, + {url.User("aes-256-cfb"), url.UserPassword("aes-256-cfb", "123456"), false}, + {url.UserPassword("aes-256-cfb", "123456"), url.User("aes-256-cfb"), false}, + {url.UserPassword("aes-256-cfb", "123456"), url.UserPassword("aes-256-cfb", "abc"), false}, + {url.UserPassword("aes-256-cfb", "123456"), url.UserPassword("aes-256-cfb", "123456"), true}, + + {url.User("aes-128-ctr"), url.User("aes-128-ctr"), true}, + {url.User("aes-128-ctr"), url.UserPassword("aes-128-ctr", "123456"), false}, + {url.UserPassword("aes-128-ctr", "123456"), url.User("aes-128-ctr"), false}, + {url.UserPassword("aes-128-ctr", "123456"), url.UserPassword("aes-128-ctr", "abc"), false}, + {url.UserPassword("aes-128-ctr", "123456"), url.UserPassword("aes-128-ctr", "123456"), true}, + + {url.User("aes-192-ctr"), url.User("aes-192-ctr"), true}, + {url.User("aes-192-ctr"), url.UserPassword("aes-192-ctr", "123456"), false}, + {url.UserPassword("aes-192-ctr", "123456"), url.User("aes-192-ctr"), false}, + {url.UserPassword("aes-192-ctr", "123456"), url.UserPassword("aes-192-ctr", "abc"), false}, + {url.UserPassword("aes-192-ctr", "123456"), url.UserPassword("aes-192-ctr", "123456"), true}, + + {url.User("aes-256-ctr"), url.User("aes-256-ctr"), true}, + {url.User("aes-256-ctr"), url.UserPassword("aes-256-ctr", "123456"), false}, + {url.UserPassword("aes-256-ctr", "123456"), url.User("aes-256-ctr"), false}, + {url.UserPassword("aes-256-ctr", "123456"), url.UserPassword("aes-256-ctr", "abc"), false}, + {url.UserPassword("aes-256-ctr", "123456"), url.UserPassword("aes-256-ctr", "123456"), true}, + + {url.User("des-cfb"), url.User("des-cfb"), true}, + {url.User("des-cfb"), url.UserPassword("des-cfb", "123456"), false}, + {url.UserPassword("des-cfb", "123456"), url.User("des-cfb"), false}, + {url.UserPassword("des-cfb", "123456"), url.UserPassword("des-cfb", "abc"), false}, + {url.UserPassword("des-cfb", "123456"), url.UserPassword("des-cfb", "123456"), true}, + + {url.User("bf-cfb"), url.User("bf-cfb"), true}, + {url.User("bf-cfb"), url.UserPassword("bf-cfb", "123456"), false}, + {url.UserPassword("bf-cfb", "123456"), url.User("bf-cfb"), false}, + {url.UserPassword("bf-cfb", "123456"), url.UserPassword("bf-cfb", "abc"), false}, + {url.UserPassword("bf-cfb", "123456"), url.UserPassword("bf-cfb", "123456"), true}, + + {url.User("cast5-cfb"), url.User("cast5-cfb"), true}, + {url.User("cast5-cfb"), url.UserPassword("cast5-cfb", "123456"), false}, + {url.UserPassword("cast5-cfb", "123456"), url.User("cast5-cfb"), false}, + {url.UserPassword("cast5-cfb", "123456"), url.UserPassword("cast5-cfb", "abc"), false}, + {url.UserPassword("cast5-cfb", "123456"), url.UserPassword("cast5-cfb", "123456"), true}, + + {url.User("rc4-md5"), url.User("rc4-md5"), true}, + {url.User("rc4-md5"), url.UserPassword("rc4-md5", "123456"), false}, + {url.UserPassword("rc4-md5", "123456"), url.User("rc4-md5"), false}, + {url.UserPassword("rc4-md5", "123456"), url.UserPassword("rc4-md5", "abc"), false}, + {url.UserPassword("rc4-md5", "123456"), url.UserPassword("rc4-md5", "123456"), true}, + + {url.User("chacha20"), url.User("chacha20"), true}, + {url.User("chacha20"), url.UserPassword("chacha20", "123456"), false}, + {url.UserPassword("chacha20", "123456"), url.User("chacha20"), false}, + {url.UserPassword("chacha20", "123456"), url.UserPassword("chacha20", "abc"), false}, + {url.UserPassword("chacha20", "123456"), url.UserPassword("chacha20", "123456"), true}, + + {url.User("chacha20-ietf"), url.User("chacha20-ietf"), true}, + {url.User("chacha20-ietf"), url.UserPassword("chacha20-ietf", "123456"), false}, + {url.UserPassword("chacha20-ietf", "123456"), url.User("chacha20-ietf"), false}, + {url.UserPassword("chacha20-ietf", "123456"), url.UserPassword("chacha20-ietf", "abc"), false}, + {url.UserPassword("chacha20-ietf", "123456"), url.UserPassword("chacha20-ietf", "123456"), true}, + + {url.User("salsa20"), url.User("salsa20"), true}, + {url.User("salsa20"), url.UserPassword("salsa20", "123456"), false}, + {url.UserPassword("salsa20", "123456"), url.User("salsa20"), false}, + {url.UserPassword("salsa20", "123456"), url.UserPassword("salsa20", "abc"), false}, + {url.UserPassword("salsa20", "123456"), url.UserPassword("salsa20", "123456"), true}, + + {url.User("xchacha20"), url.User("xchacha20"), true}, + {url.User("xchacha20"), url.UserPassword("xchacha20", "123456"), false}, + {url.UserPassword("xchacha20", "123456"), url.User("xchacha20"), false}, + {url.UserPassword("xchacha20", "123456"), url.UserPassword("xchacha20", "abc"), false}, + {url.UserPassword("xchacha20", "123456"), url.UserPassword("xchacha20", "123456"), true}, + + {url.User("CHACHA20-IETF-POLY1305"), url.User("CHACHA20-IETF-POLY1305"), true}, + {url.User("CHACHA20-IETF-POLY1305"), url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.User("CHACHA20-IETF-POLY1305"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.UserPassword("CHACHA20-IETF-POLY1305", "abc"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), true}, + + {url.User("AES-128-GCM"), url.User("AES-128-GCM"), true}, + {url.User("AES-128-GCM"), url.UserPassword("AES-128-GCM", "123456"), false}, + {url.UserPassword("AES-128-GCM", "123456"), url.User("AES-128-GCM"), false}, + {url.UserPassword("AES-128-GCM", "123456"), url.UserPassword("AES-128-GCM", "abc"), false}, + {url.UserPassword("AES-128-GCM", "123456"), url.UserPassword("AES-128-GCM", "123456"), true}, + + {url.User("AES-192-GCM"), url.User("AES-192-GCM"), true}, + {url.User("AES-192-GCM"), url.UserPassword("AES-192-GCM", "123456"), false}, + {url.UserPassword("AES-192-GCM", "123456"), url.User("AES-192-GCM"), false}, + {url.UserPassword("AES-192-GCM", "123456"), url.UserPassword("AES-192-GCM", "abc"), false}, + {url.UserPassword("AES-192-GCM", "123456"), url.UserPassword("AES-192-GCM", "123456"), true}, + + {url.User("AES-256-GCM"), url.User("AES-256-GCM"), true}, + {url.User("AES-256-GCM"), url.UserPassword("AES-256-GCM", "123456"), false}, + {url.UserPassword("AES-256-GCM", "123456"), url.User("AES-256-GCM"), false}, + {url.UserPassword("AES-256-GCM", "123456"), url.UserPassword("AES-256-GCM", "abc"), false}, + {url.UserPassword("AES-256-GCM", "123456"), url.UserPassword("AES-256-GCM", "123456"), true}, +} + +var ssProxyTests = []struct { + clientCipher *url.Userinfo + serverCipher *url.Userinfo + pass bool +}{ + {nil, nil, true}, + {&url.Userinfo{}, &url.Userinfo{}, true}, + {url.User("abc"), url.User("abc"), true}, + {url.UserPassword("abc", "def"), url.UserPassword("abc", "def"), true}, + + {url.User("aes-128-cfb"), url.User("aes-128-cfb"), true}, + {url.User("aes-128-cfb"), url.UserPassword("aes-128-cfb", "123456"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.User("aes-128-cfb"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.UserPassword("aes-128-cfb", "123456"), true}, + + {url.User("CHACHA20-IETF-POLY1305"), url.User("CHACHA20-IETF-POLY1305"), true}, + {url.User("CHACHA20-IETF-POLY1305"), url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.User("CHACHA20-IETF-POLY1305"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), true}, +} + +func ssProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo *url.Userinfo) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: ShadowHandler(UsersHandlerOption(serverInfo)), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestShadowTCP(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := ssProxyRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func BenchmarkSSProxy_AES256(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ShadowConnector(url.UserPassword("aes-256-cfb", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: ShadowHandler(UsersHandlerOption(url.UserPassword("aes-256-cfb", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSSProxy_Chacha20(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ShadowConnector(url.UserPassword("chacha20", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: ShadowHandler(UsersHandlerOption(url.UserPassword("chacha20", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSSProxy_Chacha20_ietf(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ShadowConnector(url.UserPassword("chacha20-ietf", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: ShadowHandler(UsersHandlerOption(url.UserPassword("chacha20-ietf", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSSProxyParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ShadowConnector(url.UserPassword("chacha20-ietf", "123456")), + Transporter: TCPTransporter(), + } + + server := &Server{ + Handler: ShadowHandler(UsersHandlerOption(url.UserPassword("chacha20-ietf", "123456"))), + Listener: ln, + } + + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +var ssuTests = []struct { + clientCipher *url.Userinfo + serverCipher *url.Userinfo + pass bool +}{ + {nil, nil, true}, + {&url.Userinfo{}, &url.Userinfo{}, true}, + {url.User("abc"), url.User("abc"), true}, + {url.UserPassword("abc", "def"), url.UserPassword("abc", "def"), true}, + + {url.User("aes-128-cfb"), url.User("aes-128-cfb"), true}, + {url.User("aes-128-cfb"), url.UserPassword("aes-128-cfb", "123456"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.User("aes-128-cfb"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.UserPassword("aes-128-cfb", "abc"), false}, + {url.UserPassword("aes-128-cfb", "123456"), url.UserPassword("aes-128-cfb", "123456"), true}, + + {url.User("aes-192-cfb"), url.User("aes-192-cfb"), true}, + {url.User("aes-192-cfb"), url.UserPassword("aes-192-cfb", "123456"), false}, + {url.UserPassword("aes-192-cfb", "123456"), url.User("aes-192-cfb"), false}, + {url.UserPassword("aes-192-cfb", "123456"), url.UserPassword("aes-192-cfb", "abc"), false}, + {url.UserPassword("aes-192-cfb", "123456"), url.UserPassword("aes-192-cfb", "123456"), true}, + + {url.User("aes-256-cfb"), url.User("aes-256-cfb"), true}, + {url.User("aes-256-cfb"), url.UserPassword("aes-256-cfb", "123456"), false}, + {url.UserPassword("aes-256-cfb", "123456"), url.User("aes-256-cfb"), false}, + {url.UserPassword("aes-256-cfb", "123456"), url.UserPassword("aes-256-cfb", "abc"), false}, + {url.UserPassword("aes-256-cfb", "123456"), url.UserPassword("aes-256-cfb", "123456"), true}, + + {url.User("aes-128-ctr"), url.User("aes-128-ctr"), true}, + {url.User("aes-128-ctr"), url.UserPassword("aes-128-ctr", "123456"), false}, + {url.UserPassword("aes-128-ctr", "123456"), url.User("aes-128-ctr"), false}, + {url.UserPassword("aes-128-ctr", "123456"), url.UserPassword("aes-128-ctr", "abc"), false}, + {url.UserPassword("aes-128-ctr", "123456"), url.UserPassword("aes-128-ctr", "123456"), true}, + + {url.User("aes-192-ctr"), url.User("aes-192-ctr"), true}, + {url.User("aes-192-ctr"), url.UserPassword("aes-192-ctr", "123456"), false}, + {url.UserPassword("aes-192-ctr", "123456"), url.User("aes-192-ctr"), false}, + {url.UserPassword("aes-192-ctr", "123456"), url.UserPassword("aes-192-ctr", "abc"), false}, + {url.UserPassword("aes-192-ctr", "123456"), url.UserPassword("aes-192-ctr", "123456"), true}, + + {url.User("aes-256-ctr"), url.User("aes-256-ctr"), true}, + {url.User("aes-256-ctr"), url.UserPassword("aes-256-ctr", "123456"), false}, + {url.UserPassword("aes-256-ctr", "123456"), url.User("aes-256-ctr"), false}, + {url.UserPassword("aes-256-ctr", "123456"), url.UserPassword("aes-256-ctr", "abc"), false}, + {url.UserPassword("aes-256-ctr", "123456"), url.UserPassword("aes-256-ctr", "123456"), true}, + + {url.User("des-cfb"), url.User("des-cfb"), true}, + {url.User("des-cfb"), url.UserPassword("des-cfb", "123456"), false}, + {url.UserPassword("des-cfb", "123456"), url.User("des-cfb"), false}, + {url.UserPassword("des-cfb", "123456"), url.UserPassword("des-cfb", "abc"), false}, + {url.UserPassword("des-cfb", "123456"), url.UserPassword("des-cfb", "123456"), true}, + + {url.User("bf-cfb"), url.User("bf-cfb"), true}, + {url.User("bf-cfb"), url.UserPassword("bf-cfb", "123456"), false}, + {url.UserPassword("bf-cfb", "123456"), url.User("bf-cfb"), false}, + {url.UserPassword("bf-cfb", "123456"), url.UserPassword("bf-cfb", "abc"), false}, + {url.UserPassword("bf-cfb", "123456"), url.UserPassword("bf-cfb", "123456"), true}, + + {url.User("cast5-cfb"), url.User("cast5-cfb"), true}, + {url.User("cast5-cfb"), url.UserPassword("cast5-cfb", "123456"), false}, + {url.UserPassword("cast5-cfb", "123456"), url.User("cast5-cfb"), false}, + {url.UserPassword("cast5-cfb", "123456"), url.UserPassword("cast5-cfb", "abc"), false}, + {url.UserPassword("cast5-cfb", "123456"), url.UserPassword("cast5-cfb", "123456"), true}, + + {url.User("rc4-md5"), url.User("rc4-md5"), true}, + {url.User("rc4-md5"), url.UserPassword("rc4-md5", "123456"), false}, + {url.UserPassword("rc4-md5", "123456"), url.User("rc4-md5"), false}, + {url.UserPassword("rc4-md5", "123456"), url.UserPassword("rc4-md5", "abc"), false}, + {url.UserPassword("rc4-md5", "123456"), url.UserPassword("rc4-md5", "123456"), true}, + + {url.User("chacha20"), url.User("chacha20"), true}, + {url.User("chacha20"), url.UserPassword("chacha20", "123456"), false}, + {url.UserPassword("chacha20", "123456"), url.User("chacha20"), false}, + {url.UserPassword("chacha20", "123456"), url.UserPassword("chacha20", "abc"), false}, + {url.UserPassword("chacha20", "123456"), url.UserPassword("chacha20", "123456"), true}, + + {url.User("chacha20-ietf"), url.User("chacha20-ietf"), true}, + {url.User("chacha20-ietf"), url.UserPassword("chacha20-ietf", "123456"), false}, + {url.UserPassword("chacha20-ietf", "123456"), url.User("chacha20-ietf"), false}, + {url.UserPassword("chacha20-ietf", "123456"), url.UserPassword("chacha20-ietf", "abc"), false}, + {url.UserPassword("chacha20-ietf", "123456"), url.UserPassword("chacha20-ietf", "123456"), true}, + + {url.User("salsa20"), url.User("salsa20"), true}, + {url.User("salsa20"), url.UserPassword("salsa20", "123456"), false}, + {url.UserPassword("salsa20", "123456"), url.User("salsa20"), false}, + {url.UserPassword("salsa20", "123456"), url.UserPassword("salsa20", "abc"), false}, + {url.UserPassword("salsa20", "123456"), url.UserPassword("salsa20", "123456"), true}, + + {url.User("xchacha20"), url.User("xchacha20"), true}, + {url.User("xchacha20"), url.UserPassword("xchacha20", "123456"), false}, + {url.UserPassword("xchacha20", "123456"), url.User("xchacha20"), false}, + {url.UserPassword("xchacha20", "123456"), url.UserPassword("xchacha20", "abc"), false}, + {url.UserPassword("xchacha20", "123456"), url.UserPassword("xchacha20", "123456"), true}, + + {url.User("CHACHA20-IETF-POLY1305"), url.User("CHACHA20-IETF-POLY1305"), true}, + {url.User("CHACHA20-IETF-POLY1305"), url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.User("CHACHA20-IETF-POLY1305"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.UserPassword("CHACHA20-IETF-POLY1305", "abc"), false}, + {url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), url.UserPassword("CHACHA20-IETF-POLY1305", "123456"), true}, + + {url.User("AES-128-GCM"), url.User("AES-128-GCM"), true}, + {url.User("AES-128-GCM"), url.UserPassword("AES-128-GCM", "123456"), false}, + {url.UserPassword("AES-128-GCM", "123456"), url.User("AES-128-GCM"), false}, + {url.UserPassword("AES-128-GCM", "123456"), url.UserPassword("AES-128-GCM", "abc"), false}, + {url.UserPassword("AES-128-GCM", "123456"), url.UserPassword("AES-128-GCM", "123456"), true}, + + {url.User("AES-192-GCM"), url.User("AES-192-GCM"), true}, + {url.User("AES-192-GCM"), url.UserPassword("AES-192-GCM", "123456"), false}, + {url.UserPassword("AES-192-GCM", "123456"), url.User("AES-192-GCM"), false}, + {url.UserPassword("AES-192-GCM", "123456"), url.UserPassword("AES-192-GCM", "abc"), false}, + {url.UserPassword("AES-192-GCM", "123456"), url.UserPassword("AES-192-GCM", "123456"), true}, + + {url.User("AES-256-GCM"), url.User("AES-256-GCM"), true}, + {url.User("AES-256-GCM"), url.UserPassword("AES-256-GCM", "123456"), false}, + {url.UserPassword("AES-256-GCM", "123456"), url.User("AES-256-GCM"), false}, + {url.UserPassword("AES-256-GCM", "123456"), url.UserPassword("AES-256-GCM", "abc"), false}, + {url.UserPassword("AES-256-GCM", "123456"), url.UserPassword("AES-256-GCM", "123456"), true}, +} + +func shadowUDPRoundtrip(t *testing.T, host string, data []byte, + clientInfo *url.Userinfo, serverInfo *url.Userinfo) error { + ln, err := UDPListener("localhost:0", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowUDPConnector(clientInfo), + Transporter: UDPTransporter(), + } + + server := &Server{ + Handler: ShadowUDPHandler( + UsersHandlerOption(serverInfo), + ), + Listener: ln, + } + + go server.Run() + defer server.Close() + + return udpRoundtrip(t, client, server, host, data) +} + +func TestShadowUDP(t *testing.T) { + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssuTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + err := shadowUDPRoundtrip(t, udpSrv.Addr(), sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func BenchmarkShadowUDP(b *testing.B) { + udpSrv := newUDPTestServer(udpTestHandler) + udpSrv.Start() + defer udpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := UDPListener("localhost:0", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: ShadowUDPConnector(url.UserPassword("chacha20-ietf", "123456")), + Transporter: UDPTransporter(), + } + + server := &Server{ + Handler: ShadowUDPHandler( + UsersHandlerOption(url.UserPassword("chacha20-ietf", "123456")), + ), + Listener: ln, + } + + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + b.Error(err) + } + defer conn.Close() + + conn, err = client.Connect(conn, udpSrv.Addr()) + if err != nil { + return + } + + for i := 0; i < b.N; i++ { + conn.SetDeadline(time.Now().Add(3 * time.Second)) + + if _, err = conn.Write(sendData); err != nil { + b.Error(err) + } + + recv := make([]byte, len(sendData)) + if _, err = conn.Read(recv); err != nil { + b.Error(err) + } + + conn.SetDeadline(time.Time{}) + + if !bytes.Equal(sendData, recv) { + b.Error("data not equal") + } + } +} diff --git a/ssh.go b/ssh.go new file mode 100644 index 0000000..f772503 --- /dev/null +++ b/ssh.go @@ -0,0 +1,982 @@ +package gost + +import ( + "context" + "crypto/tls" + "encoding/binary" + "errors" + "fmt" + "io/ioutil" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-log/log" + "golang.org/x/crypto/ssh" +) + +// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X +const ( + DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2 + RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 + ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 + CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 + + GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel +) + +var ( + errSessionDead = errors.New("session is dead") +) + +// ParseSSHKeyFile parses ssh key file. +func ParseSSHKeyFile(fp string) (ssh.Signer, error) { + key, err := ioutil.ReadFile(fp) + if err != nil { + return nil, err + } + return ssh.ParsePrivateKey(key) +} + +// ParseSSHAuthorizedKeysFile parses ssh Authorized Keys file. +func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) { + authorizedKeysBytes, err := ioutil.ReadFile(fp) + if err != nil { + return nil, err + } + authorizedKeysMap := make(map[string]bool) + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, err + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + return authorizedKeysMap, nil +} + +type sshDirectForwardConnector struct { +} + +// SSHDirectForwardConnector creates a Connector for SSH TCP direct port forwarding. +func SSHDirectForwardConnector() Connector { + return &sshDirectForwardConnector{} +} + +func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", raddr, options...) +} + +func (c *sshDirectForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, raddr string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + opts := &ConnectOptions{} + for _, option := range options { + option(opts) + } + + cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. + if !ok { + return nil, errors.New("ssh: wrong connection type") + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = ConnectTimeout + } + + cc.session.conn.SetDeadline(time.Now().Add(timeout)) + defer cc.session.conn.SetDeadline(time.Time{}) + + conn, err := cc.session.client.Dial("tcp", raddr) + if err != nil { + log.Logf("[ssh-tcp] %s -> %s : %s", cc.session.addr, raddr, err) + return nil, err + } + return conn, nil +} + +type sshRemoteForwardConnector struct { +} + +// SSHRemoteForwardConnector creates a Connector for SSH TCP remote port forwarding. +func SSHRemoteForwardConnector() Connector { + return &sshRemoteForwardConnector{} +} + +func (c *sshRemoteForwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { + return c.ConnectContext(context.Background(), conn, "tcp", address, options...) +} + +func (c *sshRemoteForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { + switch network { + case "udp", "udp4", "udp6": + return nil, fmt.Errorf("%s unsupported", network) + } + + cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. + if !ok { + return nil, errors.New("ssh: wrong connection type") + } + + cc.session.once.Do(func() { + go func() { + defer log.Log("ssh-rtcp: session is closed") + defer close(cc.session.connChan) + + if cc.session == nil || cc.session.client == nil { + return + } + if strings.HasPrefix(address, ":") { + address = "0.0.0.0" + address + } + ln, err := cc.session.client.Listen("tcp", address) + if err != nil { + return + } + log.Log("[ssh-rtcp] listening on", ln.Addr()) + + for { + rc, err := ln.Accept() + if err != nil { + log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), address, err) + return + } + // log.Log("[ssh-rtcp] accept", rc.LocalAddr(), rc.RemoteAddr()) + select { + case cc.session.connChan <- rc: + default: + rc.Close() + log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), address) + } + } + }() + }) + + sc, ok := <-cc.session.connChan + if !ok { + return nil, errors.New("ssh-rtcp: connection is closed") + } + return sc, nil +} + +type sshForwardTransporter struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex +} + +// SSHForwardTransporter creates a Transporter that is used by SSH port forwarding server. +func SSHForwardTransporter() Transporter { + return &sshForwardTransporter{ + sessions: make(map[string]*sshSession), + } +} + +func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + session, ok := tr.sessions[addr] + if !ok || session.Closed() { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &sshSession{ + addr: addr, + conn: conn, + } + tr.sessions[addr] = session + } + + return session.conn, nil +} + +func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + config := ssh.ClientConfig{ + Timeout: timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + if opts.User != nil { + config.User = opts.User.Username() + if password, _ := opts.User.Password(); password != "" { + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + } + if opts.SSHConfig != nil && opts.SSHConfig.Key != nil { + config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key)) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.client == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) + if err != nil { + log.Log("ssh", err) + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + session = &sshSession{ + addr: opts.Addr, + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + deaded: make(chan struct{}), + connChan: make(chan net.Conn, 1024), + } + tr.sessions[opts.Addr] = session + go session.Ping(opts.Interval, opts.Timeout, opts.Retry) + go session.waitServer() + go session.waitClose() + } + if session.Closed() { + delete(tr.sessions, opts.Addr) + return nil, errSessionDead + } + + return &sshNopConn{session: session}, nil +} + +func (tr *sshForwardTransporter) Multiplex() bool { + return true +} + +type sshTunnelTransporter struct { + sessions map[string]*sshSession + sessionMutex sync.Mutex +} + +// SSHTunnelTransporter creates a Transporter that is used by SSH tunnel client. +func SSHTunnelTransporter() Transporter { + return &sshTunnelTransporter{ + sessions: make(map[string]*sshSession), + } +} + +func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + session, ok := tr.sessions[addr] + if !ok || session.Closed() { + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &sshSession{ + addr: addr, + conn: conn, + } + tr.sessions[addr] = session + } + + return session.conn, nil +} + +func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + config := ssh.ClientConfig{ + Timeout: timeout, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + if opts.User != nil { + config.User = opts.User.Username() + if password, _ := opts.User.Password(); password != "" { + config.Auth = []ssh.AuthMethod{ + ssh.Password(password), + } + } + } + if opts.SSHConfig != nil && opts.SSHConfig.Key != nil { + config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key)) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.client == nil { + sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + session = &sshSession{ + addr: opts.Addr, + conn: conn, + client: ssh.NewClient(sshConn, chans, reqs), + closed: make(chan struct{}), + deaded: make(chan struct{}), + } + tr.sessions[opts.Addr] = session + go session.Ping(opts.Interval, opts.Timeout, opts.Retry) + go session.waitServer() + go session.waitClose() + } + + if session.Closed() { + delete(tr.sessions, opts.Addr) + return nil, errSessionDead + } + + channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(reqs) + return &sshConn{channel: channel, conn: conn}, nil +} + +func (tr *sshTunnelTransporter) Multiplex() bool { + return true +} + +type sshSession struct { + addr string + conn net.Conn + client *ssh.Client + closed chan struct{} + deaded chan struct{} + once sync.Once + connChan chan net.Conn +} + +func (s *sshSession) Ping(interval, timeout time.Duration, retries int) { + if interval <= 0 { + return + } + if timeout <= 0 { + timeout = PingTimeout + } + + if retries == 0 { + retries = 1 + } + + defer close(s.deaded) + + log.Logf("[ssh] ping is enabled, interval: %v, timeout: %v, retry: %d", interval, timeout, retries) + baseCtx := context.Background() + t := time.NewTicker(interval) + defer t.Stop() + + count := retries + 1 + for { + select { + case <-t.C: + start := time.Now() + if Debug { + log.Log("[ssh] sending ping") + } + ctx, cancel := context.WithTimeout(baseCtx, timeout) + var err error + select { + case err = <-s.sendPing(): + case <-ctx.Done(): + err = errors.New("Timeout") + } + cancel() + if err != nil { + log.Log("[ssh] ping:", err) + count-- + if count == 0 { + return + } + continue + } + if Debug { + log.Log("[ssh] ping OK, RTT:", time.Since(start)) + } + count = retries + 1 + case <-s.closed: + return + } + } +} + +func (s *sshSession) sendPing() <-chan error { + ch := make(chan error, 1) + go func() { + if _, _, err := s.client.SendRequest("ping", true, nil); err != nil { + ch <- err + } + close(ch) + }() + return ch +} + +func (s *sshSession) waitServer() error { + defer close(s.closed) + return s.client.Wait() +} + +func (s *sshSession) waitClose() { + defer s.client.Close() + + select { + case <-s.deaded: + case <-s.closed: + } +} + +func (s *sshSession) Closed() bool { + select { + case <-s.deaded: + return true + case <-s.closed: + return true + default: + } + return false +} + +type sshForwardHandler struct { + options *HandlerOptions + config *ssh.ServerConfig +} + +// SSHForwardHandler creates a server Handler for SSH port forwarding server. +func SSHForwardHandler(opts ...HandlerOption) Handler { + h := &sshForwardHandler{} + h.Init(opts...) + + return h +} + +func (h *sshForwardHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + + for _, opt := range options { + opt(h.options) + } + h.config = &ssh.ServerConfig{} + + h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Authenticator) + if h.options.Authenticator == nil { + h.config.NoClientAuth = true + } + tlsConfig := h.options.TLSConfig + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + if tlsConfig != nil && len(tlsConfig.Certificates) > 0 { + signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey) + if err != nil { + log.Log("[ssh-forward]", err) + } + h.config.AddHostKey(signer) + } +} + +func (h *sshForwardHandler) Handle(conn net.Conn) { + sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) + if err != nil { + log.Logf("[ssh-forward] %s -> %s : %s", conn.RemoteAddr(), h.options.Node.Addr, err) + conn.Close() + return + } + defer sshConn.Close() + + log.Logf("[ssh-forward] %s <-> %s", conn.RemoteAddr(), h.options.Node.Addr) + h.handleForward(sshConn, chans, reqs) + log.Logf("[ssh-forward] %s >-< %s", conn.RemoteAddr(), h.options.Node.Addr) +} + +func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + quit := make(chan struct{}) + defer close(quit) // quit signal + + go func() { + for req := range reqs { + switch req.Type { + case RemoteForwardRequest: + go h.tcpipForwardRequest(conn, req, quit) + default: + // log.Log("[ssh] unknown request type:", req.Type, req.WantReply) + if req.WantReply { + req.Reply(false, nil) + } + } + } + }() + + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case DirectForwardRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + log.Log("[ssh] Could not accept channel:", err) + continue + } + p := directForward{} + ssh.Unmarshal(newChannel.ExtraData(), &p) + + if p.Host1 == "" { + p.Host1 = "" + } + + go ssh.DiscardRequests(requests) + go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) + default: + log.Log("[ssh] Unknown channel type:", t) + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + } + } + }() + + conn.Wait() +} + +func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) { + defer channel.Close() + + log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr) + + if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) + return + } + + if h.options.Bypass.Contains(raddr) { + log.Logf("[ssh-tcp] [bypass] %s", raddr) + return + } + + conn, err := h.options.Chain.Dial(raddr, + RetryChainOption(h.options.Retries), + TimeoutChainOption(h.options.Timeout), + HostsChainOption(h.options.Hosts), + ResolverChainOption(h.options.Resolver), + ) + if err != nil { + log.Logf("[ssh-tcp] %s - %s : %s", h.options.Node.Addr, raddr, err) + return + } + defer conn.Close() + + log.Logf("[ssh-tcp] %s <-> %s", h.options.Node.Addr, raddr) + transport(conn, channel) + log.Logf("[ssh-tcp] %s >-< %s", h.options.Node.Addr, raddr) +} + +// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request +type tcpipForward struct { + Host string + Port uint32 +} + +func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { + t := tcpipForward{} + ssh.Unmarshal(req.Payload, &t) + + addr := fmt.Sprintf("%s:%d", t.Host, t.Port) + + if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) + req.Reply(false, nil) + return + } + + ln, err := net.Listen("tcp", addr) //tie to the client connection + if err != nil { + log.Log("[ssh-rtcp]", err) + req.Reply(false, nil) + return + } + defer ln.Close() + + log.Log("[ssh-rtcp] listening on tcp", ln.Addr()) + + replyFunc := func() error { + if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used + _, port, err := getHostPortFromAddr(ln.Addr()) + if err != nil { + return err + } + var b [4]byte + binary.BigEndian.PutUint32(b[:], uint32(port)) + t.Port = uint32(port) + return req.Reply(true, b[:]) + } + return req.Reply(true, nil) + } + if err := replyFunc(); err != nil { + log.Log("[ssh-rtcp]", err) + return + } + + go func() { + for { + conn, err := ln.Accept() + if err != nil { // Unable to accept new connection - listener is likely closed + return + } + + go func(conn net.Conn) { + defer conn.Close() + + p := directForward{} + var err error + + var portnum int + p.Host1 = t.Host + p.Port1 = t.Port + p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) + if err != nil { + return + } + + p.Port2 = uint32(portnum) + ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) + if err != nil { + log.Log("[ssh-rtcp] open forwarded channel:", err) + return + } + defer ch.Close() + go ssh.DiscardRequests(reqs) + + log.Logf("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + transport(ch, conn) + log.Logf("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) + }(conn) + } + }() + + <-quit +} + +// SSHConfig holds the SSH tunnel server config +type SSHConfig struct { + Authenticator Authenticator + TLSConfig *tls.Config + Key ssh.Signer + AuthorizedKeys map[string]bool +} + +type sshTunnelListener struct { + net.Listener + config *ssh.ServerConfig + connChan chan net.Conn + errChan chan error +} + +// SSHTunnelListener creates a Listener for SSH tunnel server. +func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + if config == nil { + config = &SSHConfig{} + } + + sshConfig := &ssh.ServerConfig{ + PasswordCallback: defaultSSHPasswordCallback(config.Authenticator), + PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys), + } + + if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 { + sshConfig.NoClientAuth = true + } + + signer := config.Key + if signer == nil { + signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey) + if err != nil { + ln.Close() + return nil, err + } + } + sshConfig.AddHostKey(signer) + + l := &sshTunnelListener{ + Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, + config: sshConfig, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + + go l.listenLoop() + + return l, nil +} + +func (l *sshTunnelListener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + log.Log("[ssh] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.serveConn(conn) + } +} + +func (l *sshTunnelListener) serveConn(conn net.Conn) { + sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) + if err != nil { + log.Logf("[ssh] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + conn.Close() + return + } + defer sc.Close() + + go ssh.DiscardRequests(reqs) + go func() { + for newChannel := range chans { + // Check the type of channel + t := newChannel.ChannelType() + switch t { + case GostSSHTunnelRequest: + channel, requests, err := newChannel.Accept() + if err != nil { + log.Log("[ssh] Could not accept channel:", err) + continue + } + go ssh.DiscardRequests(requests) + cc := &sshConn{conn: conn, channel: channel} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[ssh] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) + } + + default: + log.Log("[ssh] Unknown channel type:", t) + newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) + } + } + }() + + log.Logf("[ssh] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) + sc.Wait() + log.Logf("[ssh] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) +} + +func (l *sshTunnelListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" +type directForward struct { + Host1 string + Port1 uint32 + Host2 string + Port2 uint32 +} + +func (p directForward) String() string { + return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) +} + +func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { + host, portString, err := net.SplitHostPort(addr.String()) + if err != nil { + return + } + port, err = strconv.Atoi(portString) + return +} + +// PasswordCallbackFunc is a callback function used by SSH server. +// It authenticates user using a password. +type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) + +func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { + if au == nil { + return nil + } + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if au.Authenticate(conn.User(), string(password)) { + return nil, nil + } + log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) + return nil, fmt.Errorf("password rejected for %s", conn.User()) + } +} + +// PublicKeyCallbackFunc is a callback function used by SSH server. +// It offers a public key for authentication. +type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) + +func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc { + if len(keys) == 0 { + return nil + } + + return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if keys[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + } +} + +type sshNopConn struct { + session *sshSession +} + +func (c *sshNopConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *sshNopConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *sshNopConn) Close() error { + return nil +} + +func (c *sshNopConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } +} + +func (c *sshNopConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } +} + +func (c *sshNopConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshNopConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *sshNopConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +type sshConn struct { + channel ssh.Channel + conn net.Conn +} + +func (c *sshConn) Read(b []byte) (n int, err error) { + return c.channel.Read(b) +} + +func (c *sshConn) Write(b []byte) (n int, err error) { + return c.channel.Write(b) +} + +func (c *sshConn) Close() error { + return c.channel.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *sshConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *sshConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/ssh_test.go b/ssh_test.go new file mode 100644 index 0000000..20d24f0 --- /dev/null +++ b/ssh_test.go @@ -0,0 +1,581 @@ +package gost + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "net" + "net/http/httptest" + "net/url" + "testing" +) + +func sshDirectForwardRoundtrip(targetURL string, data []byte) error { + ln, err := TCPListener("") + if err != nil { + return err + } + + client := &Client{ + Connector: SSHDirectForwardConnector(), + Transporter: SSHForwardTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SSHForwardHandler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSHDirectForward(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := sshDirectForwardRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func BenchmarkSSHDirectForward(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SSHDirectForwardConnector(), + Transporter: SSHForwardTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SSHForwardHandler(), + } + + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkSSHDirectForwardParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TCPListener("") + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: SSHDirectForwardConnector(), + Transporter: SSHForwardTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SSHForwardHandler(), + } + + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func sshRemoteForwardRoundtrip(t *testing.T, targetURL string, data []byte) (err error) { + ln, err := TCPListener("") + if err != nil { + return + } + + client := &Client{ + Connector: SSHRemoteForwardConnector(), + Transporter: SSHForwardTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SSHForwardHandler(), + } + + go server.Run() + defer server.Close() + + conn, err := proxyConn(client, server) + if err != nil { + return + } + defer conn.Close() + + go func() { + conn, err = client.Connect(conn, ":0") + if err != nil { + return + } + }() + + c, err := net.Dial("tcp", conn.LocalAddr().String()) + if err != nil { + return + } + defer c.Close() + + u, err := url.Parse(targetURL) + if err != nil { + return + } + + cc, err := net.Dial("tcp", u.Host) + if err != nil { + return + } + defer cc.Close() + + go transport(conn, cc) + + t.Log("httpRoundtrip") + return httpRoundtrip(c, targetURL, data) +} + +// TODO: fix this test +func _TestSSHRemoteForward(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := sshRemoteForwardRoundtrip(t, httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func httpOverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverSSHTunnelRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverSSHTunnel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := SSHTunnelListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverSSHTunnelParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := SSHTunnelListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverSSHTunnelRoundtrip(httpSrv.URL, sendData, + nil, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverSSHTunnelRoundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverSSHTunnelRoundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverSSHTunnelRoundtrip(httpSrv.URL, sendData, + nil, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverSSHTunnelRoundtrip(targetURL string, data []byte, host string) error { + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverSSHTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverSSHTunnelRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func sshForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := SSHTunnelListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: SSHTunnelTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSHForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := sshForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/tcp.go b/tcp.go new file mode 100644 index 0000000..a255011 --- /dev/null +++ b/tcp.go @@ -0,0 +1,66 @@ +package gost + +import "net" + +// tcpTransporter is a raw TCP transporter. +type tcpTransporter struct{} + +// TCPTransporter creates a raw TCP client. +func TCPTransporter() Transporter { + return &tcpTransporter{} +} + +func (tr *tcpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + if opts.Chain == nil { + return net.DialTimeout("tcp", addr, timeout) + } + return opts.Chain.Dial(addr) +} + +func (tr *tcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *tcpTransporter) Multiplex() bool { + return false +} + +type tcpListener struct { + net.Listener +} + +// TCPListener creates a Listener for TCP proxy server. +func TCPListener(addr string) (Listener, error) { + laddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenTCP("tcp", laddr) + if err != nil { + return nil, err + } + return &tcpListener{Listener: tcpKeepAliveListener{ln}}, nil +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(KeepAliveTime) + return tc, nil +} diff --git a/tls.go b/tls.go new file mode 100644 index 0000000..8526c6f --- /dev/null +++ b/tls.go @@ -0,0 +1,328 @@ +package gost + +import ( + "crypto/tls" + "errors" + "net" + "sync" + "time" + + "github.com/go-log/log" + + smux "github.com/xtaci/smux" +) + +type tlsTransporter struct { + tcpTransporter +} + +// TLSTransporter creates a Transporter that is used by TLS proxy client. +func TLSTransporter() Transporter { + return &tlsTransporter{} +} + +func (tr *tlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + return wrapTLSClient(conn, opts.TLSConfig, timeout) +} + +type mtlsTransporter struct { + tcpTransporter + sessions map[string]*muxSession + sessionMutex sync.Mutex +} + +// MTLSTransporter creates a Transporter that is used by multiplex-TLS proxy client. +func MTLSTransporter() Transporter { + return &mtlsTransporter{ + sessions: make(map[string]*muxSession), + } +} + +func (tr *mtlsTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if session != nil && session.IsClosed() { + delete(tr.sessions, addr) + ok = false // session is dead + } + if !ok { + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &muxSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *mtlsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, opts) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (tr *mtlsTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { + if opts == nil { + opts = &HandshakeOptions{} + } + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + conn, err := wrapTLSClient(conn, opts.TLSConfig, opts.Timeout) + if err != nil { + return nil, err + } + + // stream multiplex + smuxConfig := smux.DefaultConfig() + session, err := smux.Client(conn, smuxConfig) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} + +func (tr *mtlsTransporter) Multiplex() bool { + return true +} + +type tlsListener struct { + net.Listener +} + +// TLSListener creates a Listener for TLS proxy server. +func TLSListener(addr string, config *tls.Config) (Listener, error) { + if config == nil { + config = DefaultTLSConfig + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) + return &tlsListener{ln}, nil +} + +type mtlsListener struct { + ln net.Listener + connChan chan net.Conn + errChan chan error +} + +// MTLSListener creates a Listener for multiplex-TLS proxy server. +func MTLSListener(addr string, config *tls.Config) (Listener, error) { + if config == nil { + config = DefaultTLSConfig + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + l := &mtlsListener{ + ln: tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config), + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + go l.listenLoop() + + return l, nil +} + +func (l *mtlsListener) listenLoop() { + for { + conn, err := l.ln.Accept() + if err != nil { + log.Log("[mtls] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.mux(conn) + } +} + +func (l *mtlsListener) mux(conn net.Conn) { + log.Logf("[mtls] %s - %s", conn.RemoteAddr(), l.Addr()) + smuxConfig := smux.DefaultConfig() + mux, err := smux.Server(conn, smuxConfig) + if err != nil { + log.Logf("[mtls] %s - %s : %s", conn.RemoteAddr(), l.Addr(), err) + return + } + defer mux.Close() + + log.Logf("[mtls] %s <-> %s", conn.RemoteAddr(), l.Addr()) + defer log.Logf("[mtls] %s >-< %s", conn.RemoteAddr(), l.Addr()) + + for { + stream, err := mux.AcceptStream() + if err != nil { + log.Log("[mtls] accept stream:", err) + return + } + + cc := &muxStreamConn{Conn: conn, stream: stream} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[mtls] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + } +} + +func (l *mtlsListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} +func (l *mtlsListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *mtlsListener) Close() error { + return l.ln.Close() +} + +// Wrap a net.Conn into a client tls connection, performing any +// additional verification as needed. +// +// As of go 1.3, crypto/tls only supports either doing no certificate +// verification, or doing full verification including of the peer's +// DNS name. For consul, we want to validate that the certificate is +// signed by a known CA, but because consul doesn't use DNS names for +// node names, we don't verify the certificate DNS names. Since go 1.3 +// no longer supports this mode of operation, we have to do it +// manually. +// +// This code is taken from consul: +// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go +func wrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { + var err error + var tlsConn *tls.Conn + + if timeout <= 0 { + timeout = HandshakeTimeout // default timeout + } + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + tlsConn = tls.Client(conn, tlsConfig) + + // Otherwise perform handshake, but don't verify the domain + // + // The following is lightly-modified from the doFullHandshake + // method in https://golang.org/src/crypto/tls/handshake_client.go + if err = tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + + // We can do this in `tls.Config.VerifyConnection`, which effective for + // other TLS protocols such as WebSocket. See `route.go:parseChainNode` + /* + // If crypto/tls is doing verification, there's no need to do our own. + if tlsConfig.InsecureSkipVerify == false { + return tlsConn, nil + } + + // Similarly if we use host's CA, we can do full handshake + if tlsConfig.RootCAs == nil { + return tlsConn, nil + } + + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := tlsConn.ConnectionState().PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + if err != nil { + tlsConn.Close() + return nil, err + } + */ + + return tlsConn, err +} diff --git a/tls_test.go b/tls_test.go new file mode 100644 index 0000000..1d74ce8 --- /dev/null +++ b/tls_test.go @@ -0,0 +1,810 @@ +package gost + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "net/http/httptest" + "net/url" + "testing" +) + +func httpOverTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := TLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverTLSRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverTLS(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TLSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverTLSParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := TLSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := TLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverTLSRoundtrip(httpSrv.URL, sendData, + nil, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := TLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverTLSRoundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := TLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverTLSRoundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := TLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverTLSRoundtrip(httpSrv.URL, sendData, + nil, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverTLSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := TLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverTLSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func tlsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := TLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: TLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestTLSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := tlsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func httpOverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := MTLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverMTLSRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverMTLS(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := MTLSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverMTLSParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := MTLSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := MTLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverMTLSRoundtrip(httpSrv.URL, sendData, + nil, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := MTLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverMTLSRoundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { + ln, err := MTLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverMTLSRoundtrip(httpSrv.URL, sendData, nil) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := MTLSListener("", tlsConfig) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverMTLSRoundtrip(httpSrv.URL, sendData, + nil, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverMTLSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := MTLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverMTLS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverMTLSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func mtlsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := MTLSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: MTLSTransporter(), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestMTLSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := mtlsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/tuntap.go b/tuntap.go new file mode 100644 index 0000000..e38ab63 --- /dev/null +++ b/tuntap.go @@ -0,0 +1,815 @@ +package gost + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "time" + + "github.com/go-log/log" + "github.com/shadowsocks/go-shadowsocks2/core" + "github.com/shadowsocks/go-shadowsocks2/shadowaead" + "github.com/songgao/water" + "github.com/songgao/water/waterutil" + "github.com/xtaci/tcpraw" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var mIPProts = map[waterutil.IPProtocol]string{ + waterutil.HOPOPT: "HOPOPT", + waterutil.ICMP: "ICMP", + waterutil.IGMP: "IGMP", + waterutil.GGP: "GGP", + waterutil.TCP: "TCP", + waterutil.UDP: "UDP", + waterutil.IPv6_Route: "IPv6-Route", + waterutil.IPv6_Frag: "IPv6-Frag", + waterutil.IPv6_ICMP: "IPv6-ICMP", +} + +func ipProtocol(p waterutil.IPProtocol) string { + if v, ok := mIPProts[p]; ok { + return v + } + return fmt.Sprintf("unknown(%d)", p) +} + +// IPRoute is an IP routing entry. +type IPRoute struct { + Dest *net.IPNet + Gateway net.IP +} + +// TunConfig is the config for TUN device. +type TunConfig struct { + Name string + Addr string + Peer string // peer addr of point-to-point on MacOS + MTU int + Routes []IPRoute + Gateway string +} + +type tunRouteKey [16]byte + +func ipToTunRouteKey(ip net.IP) (key tunRouteKey) { + copy(key[:], ip.To16()) + return +} + +type tunListener struct { + addr net.Addr + conns chan net.Conn + closed chan struct{} + config TunConfig +} + +// TunListener creates a listener for tun tunnel. +func TunListener(cfg TunConfig) (Listener, error) { + threads := 1 + ln := &tunListener{ + conns: make(chan net.Conn, threads), + closed: make(chan struct{}), + config: cfg, + } + + for i := 0; i < threads; i++ { + conn, ifce, err := createTun(cfg) + if err != nil { + return nil, err + } + ln.addr = conn.LocalAddr() + + addrs, _ := ifce.Addrs() + log.Logf("[tun] %s: name: %s, mtu: %d, addrs: %s", + conn.LocalAddr(), ifce.Name, ifce.MTU, addrs) + + ln.conns <- conn + } + + return ln, nil +} + +func (l *tunListener) Accept() (net.Conn, error) { + select { + case conn := <-l.conns: + return conn, nil + case <-l.closed: + } + + return nil, errors.New("accept on closed listener") +} + +func (l *tunListener) Addr() net.Addr { + return l.addr +} + +func (l *tunListener) Close() error { + select { + case <-l.closed: + return errors.New("listener has been closed") + default: + close(l.closed) + } + return nil +} + +type tunHandler struct { + options *HandlerOptions + routes sync.Map + chExit chan struct{} +} + +// TunHandler creates a handler for tun tunnel. +func TunHandler(opts ...HandlerOption) Handler { + h := &tunHandler{ + options: &HandlerOptions{}, + chExit: make(chan struct{}, 1), + } + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *tunHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + +func (h *tunHandler) Handle(conn net.Conn) { + defer os.Exit(0) + defer conn.Close() + + var err error + var raddr net.Addr + if addr := h.options.Node.Remote; addr != "" { + raddr, err = net.ResolveUDPAddr("udp", addr) + if err != nil { + log.Logf("[tun] %s: remote addr: %v", conn.LocalAddr(), err) + return + } + } + + var tempDelay time.Duration + for { + err := func() error { + var err error + var pc net.PacketConn + // fake tcp mode will be ignored when the client specifies a chain. + if raddr != nil && !h.options.Chain.IsEmpty() { + cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String()) + if err != nil { + return err + } + var ok bool + pc, ok = cc.(net.PacketConn) + if !ok { + err = errors.New("not a packet connection") + log.Logf("[tun] %s - %s: %s", conn.LocalAddr(), raddr, err) + return err + } + } else { + if h.options.TCPMode { + if raddr != nil { + pc, err = tcpraw.Dial("tcp", raddr.String()) + } else { + pc, err = tcpraw.Listen("tcp", h.options.Node.Addr) + } + } else { + laddr, _ := net.ResolveUDPAddr("udp", h.options.Node.Addr) + pc, err = net.ListenUDP("udp", laddr) + } + } + if err != nil { + return err + } + + pc, err = h.initTunnelConn(pc) + if err != nil { + return err + } + + return h.transportTun(conn, pc, raddr) + }() + if err != nil { + log.Logf("[tun] %s: %v", conn.LocalAddr(), err) + } + + select { + case <-h.chExit: + return + default: + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + tempDelay = 0 + } +} + +func (h *tunHandler) initTunnelConn(pc net.PacketConn) (net.PacketConn, error) { + if len(h.options.Users) > 0 && h.options.Users[0] != nil { + passwd, _ := h.options.Users[0].Password() + cipher, err := core.PickCipher(h.options.Users[0].Username(), nil, passwd) + if err != nil { + return nil, err + } + pc = cipher.PacketConn(pc) + } + return pc, nil +} + +func (h *tunHandler) findRouteFor(dst net.IP) net.Addr { + if v, ok := h.routes.Load(ipToTunRouteKey(dst)); ok { + return v.(net.Addr) + } + for _, route := range h.options.IPRoutes { + if route.Dest.Contains(dst) && route.Gateway != nil { + if v, ok := h.routes.Load(ipToTunRouteKey(route.Gateway)); ok { + return v.(net.Addr) + } + } + } + return nil +} + +func (h *tunHandler) transportTun(tun net.Conn, conn net.PacketConn, raddr net.Addr) error { + errc := make(chan error, 1) + + go func() { + for { + err := func() error { + b := sPool.Get().([]byte) + defer sPool.Put(b) + + n, err := tun.Read(b) + if err != nil { + select { + case h.chExit <- struct{}{}: + default: + } + return err + } + + var src, dst net.IP + if waterutil.IsIPv4(b[:n]) { + header, err := ipv4.ParseHeader(b[:n]) + if err != nil { + log.Logf("[tun] %s: %v", tun.LocalAddr(), err) + return nil + } + if Debug { + log.Logf("[tun] %s -> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol(b[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + } + src, dst = header.Src, header.Dst + } else if waterutil.IsIPv6(b[:n]) { + header, err := ipv6.ParseHeader(b[:n]) + if err != nil { + log.Logf("[tun] %s: %v", tun.LocalAddr(), err) + return nil + } + if Debug { + log.Logf("[tun] %s -> %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + } + src, dst = header.Src, header.Dst + } else { + log.Logf("[tun] unknown packet") + return nil + } + + // client side, deliver packet directly. + if raddr != nil { + _, err := conn.WriteTo(b[:n], raddr) + return err + } + + addr := h.findRouteFor(dst) + if addr == nil { + log.Logf("[tun] no route for %s -> %s", src, dst) + return nil + } + + if Debug { + log.Logf("[tun] find route: %s -> %s", dst, addr) + } + if _, err := conn.WriteTo(b[:n], addr); err != nil { + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := sPool.Get().([]byte) + defer sPool.Put(b) + + n, addr, err := conn.ReadFrom(b) + if err != nil && + err != shadowaead.ErrShortPacket { + return err + } + + var src, dst net.IP + if waterutil.IsIPv4(b[:n]) { + header, err := ipv4.ParseHeader(b[:n]) + if err != nil { + log.Logf("[tun] %s: %v", tun.LocalAddr(), err) + return nil + } + if Debug { + log.Logf("[tun] %s -> %s %-4s %d/%-4d %-4x %d", + header.Src, header.Dst, ipProtocol(waterutil.IPv4Protocol(b[:n])), + header.Len, header.TotalLen, header.ID, header.Flags) + } + src, dst = header.Src, header.Dst + } else if waterutil.IsIPv6(b[:n]) { + header, err := ipv6.ParseHeader(b[:n]) + if err != nil { + log.Logf("[tun] %s: %v", tun.LocalAddr(), err) + return nil + } + if Debug { + log.Logf("[tun] %s -> %s %s %d %d", + header.Src, header.Dst, + ipProtocol(waterutil.IPProtocol(header.NextHeader)), + header.PayloadLen, header.TrafficClass) + } + src, dst = header.Src, header.Dst + } else { + log.Logf("[tun] unknown packet") + return nil + } + + // client side, deliver packet to tun device. + if raddr != nil { + _, err := tun.Write(b[:n]) + return err + } + + rkey := ipToTunRouteKey(src) + if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { + if actual.(net.Addr).String() != addr.String() { + log.Logf("[tun] update route: %s -> %s (old %s)", + src, addr, actual.(net.Addr)) + h.routes.Store(rkey, addr) + } + } else { + log.Logf("[tun] new route: %s -> %s", src, addr) + } + + if addr := h.findRouteFor(dst); addr != nil { + if Debug { + log.Logf("[tun] find route: %s -> %s", dst, addr) + } + _, err := conn.WriteTo(b[:n], addr) + return err + } + + if _, err := tun.Write(b[:n]); err != nil { + select { + case h.chExit <- struct{}{}: + default: + } + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} + +var mEtherTypes = map[waterutil.Ethertype]string{ + waterutil.IPv4: "ip", + waterutil.ARP: "arp", + waterutil.RARP: "rarp", + waterutil.IPv6: "ip6", +} + +func etherType(et waterutil.Ethertype) string { + if s, ok := mEtherTypes[et]; ok { + return s + } + return fmt.Sprintf("unknown(%v)", et) +} + +// TapConfig is the config for TAP device. +type TapConfig struct { + Name string + Addr string + MTU int + Routes []string + Gateway string +} + +type tapRouteKey [6]byte + +func hwAddrToTapRouteKey(addr net.HardwareAddr) (key tapRouteKey) { + copy(key[:], addr) + return +} + +type tapListener struct { + addr net.Addr + conns chan net.Conn + closed chan struct{} + config TapConfig +} + +// TapListener creates a listener for tap tunnel. +func TapListener(cfg TapConfig) (Listener, error) { + threads := 1 + ln := &tapListener{ + conns: make(chan net.Conn, threads), + closed: make(chan struct{}), + config: cfg, + } + + for i := 0; i < threads; i++ { + conn, ifce, err := createTap(cfg) + if err != nil { + return nil, err + } + ln.addr = conn.LocalAddr() + + addrs, _ := ifce.Addrs() + log.Logf("[tap] %s: name: %s, mac: %s, mtu: %d, addrs: %s", + conn.LocalAddr(), ifce.Name, ifce.HardwareAddr, ifce.MTU, addrs) + + ln.conns <- conn + } + return ln, nil +} + +func (l *tapListener) Accept() (net.Conn, error) { + select { + case conn := <-l.conns: + return conn, nil + case <-l.closed: + } + + return nil, errors.New("accept on closed listener") +} + +func (l *tapListener) Addr() net.Addr { + return l.addr +} + +func (l *tapListener) Close() error { + select { + case <-l.closed: + return errors.New("listener has been closed") + default: + close(l.closed) + } + return nil +} + +type tapHandler struct { + options *HandlerOptions + routes sync.Map + chExit chan struct{} +} + +// TapHandler creates a handler for tap tunnel. +func TapHandler(opts ...HandlerOption) Handler { + h := &tapHandler{ + options: &HandlerOptions{}, + chExit: make(chan struct{}, 1), + } + for _, opt := range opts { + opt(h.options) + } + + return h +} + +func (h *tapHandler) Init(options ...HandlerOption) { + if h.options == nil { + h.options = &HandlerOptions{} + } + for _, opt := range options { + opt(h.options) + } +} + +func (h *tapHandler) Handle(conn net.Conn) { + defer os.Exit(0) + defer conn.Close() + + var err error + var raddr net.Addr + if addr := h.options.Node.Remote; addr != "" { + raddr, err = net.ResolveUDPAddr("udp", addr) + if err != nil { + log.Logf("[tap] %s: remote addr: %v", conn.LocalAddr(), err) + return + } + } + + var tempDelay time.Duration + for { + err := func() error { + var err error + var pc net.PacketConn + // fake tcp mode will be ignored when the client specifies a chain. + if raddr != nil && !h.options.Chain.IsEmpty() { + cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String()) + if err != nil { + return err + } + var ok bool + pc, ok = cc.(net.PacketConn) + if !ok { + err = errors.New("not a packet connection") + log.Logf("[tap] %s - %s: %s", conn.LocalAddr(), raddr, err) + return err + } + } else { + if h.options.TCPMode { + if raddr != nil { + pc, err = tcpraw.Dial("tcp", raddr.String()) + } else { + pc, err = tcpraw.Listen("tcp", h.options.Node.Addr) + } + } else { + laddr, _ := net.ResolveUDPAddr("udp", h.options.Node.Addr) + pc, err = net.ListenUDP("udp", laddr) + } + } + if err != nil { + return err + } + + pc, err = h.initTunnelConn(pc) + if err != nil { + return err + } + + return h.transportTap(conn, pc, raddr) + }() + if err != nil { + log.Logf("[tap] %s: %v", conn.LocalAddr(), err) + } + + select { + case <-h.chExit: + return + default: + } + + if err != nil { + if tempDelay == 0 { + tempDelay = 1000 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 6 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + tempDelay = 0 + } +} + +func (h *tapHandler) initTunnelConn(pc net.PacketConn) (net.PacketConn, error) { + if len(h.options.Users) > 0 && h.options.Users[0] != nil { + passwd, _ := h.options.Users[0].Password() + cipher, err := core.PickCipher(h.options.Users[0].Username(), nil, passwd) + if err != nil { + return nil, err + } + pc = cipher.PacketConn(pc) + } + return pc, nil +} + +func (h *tapHandler) transportTap(tap net.Conn, conn net.PacketConn, raddr net.Addr) error { + errc := make(chan error, 1) + + go func() { + for { + err := func() error { + b := sPool.Get().([]byte) + defer sPool.Put(b) + + n, err := tap.Read(b) + if err != nil { + select { + case h.chExit <- struct{}{}: + default: + } + return err + } + + src := waterutil.MACSource(b[:n]) + dst := waterutil.MACDestination(b[:n]) + eType := etherType(waterutil.MACEthertype(b[:n])) + + if Debug { + log.Logf("[tap] %s -> %s %s %d", src, dst, eType, n) + } + + // client side, deliver frame directly. + if raddr != nil { + _, err := conn.WriteTo(b[:n], raddr) + return err + } + + // server side, broadcast. + if waterutil.IsBroadcast(dst) { + go h.routes.Range(func(k, v interface{}) bool { + conn.WriteTo(b[:n], v.(net.Addr)) + return true + }) + return nil + } + + var addr net.Addr + if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok { + addr = v.(net.Addr) + } + if addr == nil { + log.Logf("[tap] no route for %s -> %s %s %d", src, dst, eType, n) + return nil + } + + if _, err := conn.WriteTo(b[:n], addr); err != nil { + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := sPool.Get().([]byte) + defer sPool.Put(b) + + n, addr, err := conn.ReadFrom(b) + if err != nil && + err != shadowaead.ErrShortPacket { + return err + } + + src := waterutil.MACSource(b[:n]) + dst := waterutil.MACDestination(b[:n]) + eType := etherType(waterutil.MACEthertype(b[:n])) + + if Debug { + log.Logf("[tap] %s -> %s %s %d", src, dst, eType, n) + } + + // client side, deliver frame to tap device. + if raddr != nil { + _, err := tap.Write(b[:n]) + return err + } + + // server side, record route. + rkey := hwAddrToTapRouteKey(src) + if actual, loaded := h.routes.LoadOrStore(rkey, addr); loaded { + if actual.(net.Addr).String() != addr.String() { + log.Logf("[tap] update route: %s -> %s (old %s)", + src, addr, actual.(net.Addr)) + h.routes.Store(rkey, addr) + } + } else { + log.Logf("[tap] new route: %s -> %s", src, addr) + } + + if waterutil.IsBroadcast(dst) { + go h.routes.Range(func(k, v interface{}) bool { + if k.(tapRouteKey) != rkey { + conn.WriteTo(b[:n], v.(net.Addr)) + } + return true + }) + } + + if v, ok := h.routes.Load(hwAddrToTapRouteKey(dst)); ok { + if Debug { + log.Logf("[tap] find route: %s -> %s", dst, v) + } + _, err := conn.WriteTo(b[:n], v.(net.Addr)) + return err + } + + if _, err := tap.Write(b[:n]); err != nil { + select { + case h.chExit <- struct{}{}: + default: + } + return err + } + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + err := <-errc + if err != nil && err == io.EOF { + err = nil + } + return err +} + +type tunTapConn struct { + ifce *water.Interface + addr net.Addr +} + +func (c *tunTapConn) Read(b []byte) (n int, err error) { + return c.ifce.Read(b) +} + +func (c *tunTapConn) Write(b []byte) (n int, err error) { + return c.ifce.Write(b) +} + +func (c *tunTapConn) Close() (err error) { + return c.ifce.Close() +} + +func (c *tunTapConn) LocalAddr() net.Addr { + return c.addr +} + +func (c *tunTapConn) RemoteAddr() net.Addr { + return &net.IPAddr{} +} + +func (c *tunTapConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *tunTapConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *tunTapConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "tuntap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +// IsIPv6Multicast reports whether the address addr is an IPv6 multicast address. +func IsIPv6Multicast(addr net.HardwareAddr) bool { + return addr[0] == 0x33 && addr[1] == 0x33 +} diff --git a/tuntap_darwin.go b/tuntap_darwin.go new file mode 100644 index 0000000..ffe2ac7 --- /dev/null +++ b/tuntap_darwin.go @@ -0,0 +1,79 @@ +package gost + +import ( + "errors" + "fmt" + "net" + "os/exec" + "strings" + + "github.com/go-log/log" + "github.com/songgao/water" +) + +func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { + ip, _, err := net.ParseCIDR(cfg.Addr) + if err != nil { + return + } + + ifce, err := water.New(water.Config{ + DeviceType: water.TUN, + }) + if err != nil { + return + } + + mtu := cfg.MTU + if mtu <= 0 { + mtu = DefaultMTU + } + + peer := cfg.Peer + if peer == "" { + peer = ip.String() + } + cmd := fmt.Sprintf("ifconfig %s inet %s %s mtu %d up", + ifce.Name(), cfg.Addr, peer, mtu) + log.Log("[tun]", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = addTunRoutes(ifce.Name(), cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { + err = errors.New("tap is not supported on darwin") + return +} + +func addTunRoutes(ifName string, routes ...IPRoute) error { + for _, route := range routes { + if route.Dest == nil { + continue + } + cmd := fmt.Sprintf("route add -net %s -interface %s", route.Dest.String(), ifName) + log.Log("[tun]", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} diff --git a/tuntap_linux.go b/tuntap_linux.go new file mode 100644 index 0000000..2948773 --- /dev/null +++ b/tuntap_linux.go @@ -0,0 +1,173 @@ +package gost + +import ( + "errors" + "fmt" + "net" + "syscall" + + "github.com/docker/libcontainer/netlink" + "github.com/go-log/log" + "github.com/milosgajdos/tenus" + "github.com/songgao/water" +) + +func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { + ip, ipNet, err := net.ParseCIDR(cfg.Addr) + if err != nil { + return + } + + ifce, err := water.New(water.Config{ + DeviceType: water.TUN, + PlatformSpecificParams: water.PlatformSpecificParams{ + Name: cfg.Name, + }, + }) + if err != nil { + return + } + + link, err := tenus.NewLinkFrom(ifce.Name()) + if err != nil { + return + } + + mtu := cfg.MTU + if mtu <= 0 { + mtu = DefaultMTU + } + + cmd := fmt.Sprintf("ip link set dev %s mtu %d", ifce.Name(), mtu) + log.Log("[tun]", cmd) + if er := link.SetLinkMTU(mtu); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + cmd = fmt.Sprintf("ip address add %s dev %s", cfg.Addr, ifce.Name()) + log.Log("[tun]", cmd) + if er := link.SetLinkIp(ip, ipNet); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + cmd = fmt.Sprintf("ip link set dev %s up", ifce.Name()) + log.Log("[tun]", cmd) + if er := link.SetLinkUp(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = addTunRoutes(ifce.Name(), cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { + var ip net.IP + var ipNet *net.IPNet + if cfg.Addr != "" { + ip, ipNet, err = net.ParseCIDR(cfg.Addr) + if err != nil { + return + } + } + + ifce, err := water.New(water.Config{ + DeviceType: water.TAP, + PlatformSpecificParams: water.PlatformSpecificParams{ + Name: cfg.Name, + }, + }) + if err != nil { + return + } + + link, err := tenus.NewLinkFrom(ifce.Name()) + if err != nil { + return + } + + mtu := cfg.MTU + if mtu <= 0 { + mtu = DefaultMTU + } + + cmd := fmt.Sprintf("ip link set dev %s mtu %d", ifce.Name(), mtu) + log.Log("[tap]", cmd) + if er := link.SetLinkMTU(mtu); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if cfg.Addr != "" { + cmd = fmt.Sprintf("ip address add %s dev %s", cfg.Addr, ifce.Name()) + log.Log("[tap]", cmd) + if er := link.SetLinkIp(ip, ipNet); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + } + + cmd = fmt.Sprintf("ip link set dev %s up", ifce.Name()) + log.Log("[tap]", cmd) + if er := link.SetLinkUp(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = addTapRoutes(ifce.Name(), cfg.Gateway, cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func addTunRoutes(ifName string, routes ...IPRoute) error { + for _, route := range routes { + if route.Dest == nil { + continue + } + cmd := fmt.Sprintf("ip route add %s dev %s", route.Dest.String(), ifName) + log.Logf("[tun] %s", cmd) + if err := netlink.AddRoute(route.Dest.String(), "", "", ifName); err != nil && !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("%s: %v", cmd, err) + } + } + return nil +} + +func addTapRoutes(ifName string, gw string, routes ...string) error { + for _, route := range routes { + if route == "" { + continue + } + cmd := fmt.Sprintf("ip route add %s via %s dev %s", route, gw, ifName) + log.Logf("[tap] %s", cmd) + if err := netlink.AddRoute(route, "", gw, ifName); err != nil { + return fmt.Errorf("%s: %v", cmd, err) + } + } + return nil +} diff --git a/tuntap_unix.go b/tuntap_unix.go new file mode 100644 index 0000000..04d6db0 --- /dev/null +++ b/tuntap_unix.go @@ -0,0 +1,133 @@ +//go:build !linux && !windows && !darwin +// +build !linux,!windows,!darwin + +package gost + +import ( + "fmt" + "net" + "os/exec" + "strings" + + "github.com/go-log/log" + "github.com/songgao/water" +) + +func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { + ip, _, err := net.ParseCIDR(cfg.Addr) + if err != nil { + return + } + + ifce, err := water.New(water.Config{ + DeviceType: water.TUN, + }) + if err != nil { + return + } + + mtu := cfg.MTU + if mtu <= 0 { + mtu = DefaultMTU + } + + cmd := fmt.Sprintf("ifconfig %s inet %s mtu %d up", ifce.Name(), cfg.Addr, mtu) + log.Log("[tun]", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = addTunRoutes(ifce.Name(), cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { + ip, _, _ := net.ParseCIDR(cfg.Addr) + + ifce, err := water.New(water.Config{ + DeviceType: water.TAP, + }) + if err != nil { + return + } + + mtu := cfg.MTU + if mtu <= 0 { + mtu = DefaultMTU + } + + var cmd string + if cfg.Addr != "" { + cmd = fmt.Sprintf("ifconfig %s inet %s mtu %d up", ifce.Name(), cfg.Addr, mtu) + } else { + cmd = fmt.Sprintf("ifconfig %s mtu %d up", ifce.Name(), mtu) + } + log.Log("[tap]", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = addTapRoutes(ifce.Name(), cfg.Gateway, cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func addTunRoutes(ifName string, routes ...IPRoute) error { + for _, route := range routes { + if route.Dest == nil { + continue + } + cmd := fmt.Sprintf("route add -net %s -interface %s", route.Dest.String(), ifName) + log.Logf("[tun] %s", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} + +func addTapRoutes(ifName string, gw string, routes ...string) error { + for _, route := range routes { + if route == "" { + continue + } + cmd := fmt.Sprintf("route add -net %s dev %s", route, ifName) + if gw != "" { + cmd += " gw " + gw + } + log.Logf("[tap] %s", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} diff --git a/tuntap_windows.go b/tuntap_windows.go new file mode 100644 index 0000000..dd467aa --- /dev/null +++ b/tuntap_windows.go @@ -0,0 +1,153 @@ +package gost + +import ( + "fmt" + "net" + "os/exec" + "strings" + + "github.com/go-log/log" + "github.com/songgao/water" +) + +func createTun(cfg TunConfig) (conn net.Conn, itf *net.Interface, err error) { + ip, ipNet, err := net.ParseCIDR(cfg.Addr) + if err != nil { + return + } + + ifce, err := water.New(water.Config{ + DeviceType: water.TUN, + PlatformSpecificParams: water.PlatformSpecificParams{ + ComponentID: "tap0901", + InterfaceName: cfg.Name, + Network: cfg.Addr, + }, + }) + if err != nil { + return + } + + cmd := fmt.Sprintf("netsh interface ip set address name=\"%s\" "+ + "source=static addr=%s mask=%s gateway=none", + ifce.Name(), ip.String(), ipMask(ipNet.Mask)) + log.Log("[tun]", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + + if err = addTunRoutes(ifce.Name(), cfg.Gateway, cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func createTap(cfg TapConfig) (conn net.Conn, itf *net.Interface, err error) { + ip, ipNet, _ := net.ParseCIDR(cfg.Addr) + + ifce, err := water.New(water.Config{ + DeviceType: water.TAP, + PlatformSpecificParams: water.PlatformSpecificParams{ + ComponentID: "tap0901", + InterfaceName: cfg.Name, + Network: cfg.Addr, + }, + }) + if err != nil { + return + } + + if ip != nil && ipNet != nil { + cmd := fmt.Sprintf("netsh interface ip set address name=\"%s\" "+ + "source=static addr=%s mask=%s gateway=none", + ifce.Name(), ip.String(), ipMask(ipNet.Mask)) + log.Log("[tap]", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + err = fmt.Errorf("%s: %v", cmd, er) + return + } + } + + if err = addTapRoutes(ifce.Name(), cfg.Gateway, cfg.Routes...); err != nil { + return + } + + itf, err = net.InterfaceByName(ifce.Name()) + if err != nil { + return + } + + conn = &tunTapConn{ + ifce: ifce, + addr: &net.IPAddr{IP: ip}, + } + return +} + +func addTunRoutes(ifName string, gw string, routes ...IPRoute) error { + for _, route := range routes { + if route.Dest == nil { + continue + } + + deleteRoute(ifName, route.Dest.String()) + + cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=\"%s\" store=active", + route.Dest.String(), ifName) + if gw != "" { + cmd += " nexthop=" + gw + } + log.Logf("[tun] %s", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} + +func addTapRoutes(ifName string, gw string, routes ...string) error { + for _, route := range routes { + if route == "" { + continue + } + + deleteRoute(ifName, route) + + cmd := fmt.Sprintf("netsh interface ip add route prefix=%s interface=\"%s\" store=active", + route, ifName) + if gw != "" { + cmd += " nexthop=" + gw + } + log.Logf("[tap] %s", cmd) + args := strings.Split(cmd, " ") + if er := exec.Command(args[0], args[1:]...).Run(); er != nil { + return fmt.Errorf("%s: %v", cmd, er) + } + } + return nil +} + +func deleteRoute(ifName string, route string) error { + cmd := fmt.Sprintf("netsh interface ip delete route prefix=%s interface=\"%s\" store=active", + route, ifName) + args := strings.Split(cmd, " ") + return exec.Command(args[0], args[1:]...).Run() +} + +func ipMask(mask net.IPMask) string { + return fmt.Sprintf("%d.%d.%d.%d", mask[0], mask[1], mask[2], mask[3]) +} diff --git a/udp.go b/udp.go new file mode 100644 index 0000000..c2a71ab --- /dev/null +++ b/udp.go @@ -0,0 +1,356 @@ +package gost + +import ( + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/go-log/log" +) + +// udpTransporter is a raw UDP transporter. +type udpTransporter struct{} + +// UDPTransporter creates a Transporter for UDP client. +func UDPTransporter() Transporter { + return &udpTransporter{} +} + +func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) { + taddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + conn, err := net.DialUDP("udp", nil, taddr) + if err != nil { + return nil, err + } + return &udpClientConn{ + UDPConn: conn, + }, nil +} + +func (tr *udpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + return conn, nil +} + +func (tr *udpTransporter) Multiplex() bool { + return false +} + +// UDPListenConfig is the config for UDP Listener. +type UDPListenConfig struct { + TTL time.Duration // timeout per connection + Backlog int // connection backlog + QueueSize int // recv queue size per connection +} + +type udpListener struct { + ln net.PacketConn + connChan chan net.Conn + errChan chan error + connMap *udpConnMap + config *UDPListenConfig +} + +// UDPListener creates a Listener for UDP server. +func UDPListener(addr string, cfg *UDPListenConfig) (Listener, error) { + laddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + ln, err := net.ListenUDP("udp", laddr) + if err != nil { + return nil, err + } + + if cfg == nil { + cfg = &UDPListenConfig{} + } + + backlog := cfg.Backlog + if backlog <= 0 { + backlog = defaultBacklog + } + + l := &udpListener{ + ln: ln, + connChan: make(chan net.Conn, backlog), + errChan: make(chan error, 1), + connMap: new(udpConnMap), + config: cfg, + } + go l.listenLoop() + return l, nil +} + +func (l *udpListener) listenLoop() { + for { + // NOTE: this buffer will be released in the udpServerConn after read. + b := mPool.Get().([]byte) + + n, raddr, err := l.ln.ReadFrom(b) + if err != nil { + log.Logf("[udp] peer -> %s : %s", l.Addr(), err) + l.Close() + l.errChan <- err + close(l.errChan) + return + } + + conn, ok := l.connMap.Get(raddr.String()) + if !ok { + conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{ + ttl: l.config.TTL, + qsize: l.config.QueueSize, + onClose: func() { + l.connMap.Delete(raddr.String()) + log.Logf("[udp] %s closed (%d)", raddr, l.connMap.Size()) + }, + }) + + select { + case l.connChan <- conn: + l.connMap.Set(raddr.String(), conn) + log.Logf("[udp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) + default: + conn.Close() + log.Logf("[udp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan)) + } + } + + select { + case conn.rChan <- b[:n]: + if Debug { + log.Logf("[udp] %s >>> %s : length %d", raddr, l.Addr(), n) + } + default: + log.Logf("[udp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan)) + } + } +} + +func (l *udpListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *udpListener) Addr() net.Addr { + return l.ln.LocalAddr() +} + +func (l *udpListener) Close() error { + err := l.ln.Close() + l.connMap.Range(func(k interface{}, v *udpServerConn) bool { + v.Close() + return true + }) + + return err +} + +type udpConnMap struct { + size int64 + m sync.Map +} + +func (m *udpConnMap) Get(key interface{}) (conn *udpServerConn, ok bool) { + v, ok := m.m.Load(key) + if ok { + conn, ok = v.(*udpServerConn) + } + return +} + +func (m *udpConnMap) Set(key interface{}, conn *udpServerConn) { + m.m.Store(key, conn) + atomic.AddInt64(&m.size, 1) +} + +func (m *udpConnMap) Delete(key interface{}) { + m.m.Delete(key) + atomic.AddInt64(&m.size, -1) +} + +func (m *udpConnMap) Range(f func(key interface{}, value *udpServerConn) bool) { + m.m.Range(func(k, v interface{}) bool { + return f(k, v.(*udpServerConn)) + }) +} + +func (m *udpConnMap) Size() int64 { + return atomic.LoadInt64(&m.size) +} + +// udpServerConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. +type udpServerConn struct { + conn net.PacketConn + raddr net.Addr + rChan chan []byte + closed chan struct{} + closeMutex sync.Mutex + nopChan chan int + config *udpServerConnConfig +} + +type udpServerConnConfig struct { + ttl time.Duration + qsize int + onClose func() +} + +func newUDPServerConn(conn net.PacketConn, raddr net.Addr, cfg *udpServerConnConfig) *udpServerConn { + if conn == nil || raddr == nil { + return nil + } + + if cfg == nil { + cfg = &udpServerConnConfig{} + } + qsize := cfg.qsize + if qsize <= 0 { + qsize = defaultQueueSize + } + c := &udpServerConn{ + conn: conn, + raddr: raddr, + rChan: make(chan []byte, qsize), + closed: make(chan struct{}), + nopChan: make(chan int), + config: cfg, + } + go c.ttlWait() + return c +} + +func (c *udpServerConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *udpServerConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + select { + case bb := <-c.rChan: + n = copy(b, bb) + if cap(bb) == mediumBufferSize { + mPool.Put(bb[:cap(bb)]) + } + case <-c.closed: + err = errors.New("read from closed connection") + return + } + + select { + case c.nopChan <- n: + default: + } + + addr = c.raddr + + return +} + +func (c *udpServerConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.raddr) +} + +func (c *udpServerConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + n, err = c.conn.WriteTo(b, addr) + + if n > 0 { + if Debug { + log.Logf("[udp] %s <<< %s : length %d", addr, c.LocalAddr(), n) + } + + select { + case c.nopChan <- n: + default: + } + } + + return +} + +func (c *udpServerConn) Close() error { + c.closeMutex.Lock() + defer c.closeMutex.Unlock() + + select { + case <-c.closed: + return errors.New("connection is closed") + default: + if c.config.onClose != nil { + c.config.onClose() + } + close(c.closed) + } + return nil +} + +func (c *udpServerConn) ttlWait() { + ttl := c.config.ttl + if ttl == 0 { + ttl = defaultTTL + } + timer := time.NewTimer(ttl) + defer timer.Stop() + + for { + select { + case <-c.nopChan: + if !timer.Stop() { + <-timer.C + } + timer.Reset(ttl) + case <-timer.C: + c.Close() + return + case <-c.closed: + return + } + } +} + +func (c *udpServerConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *udpServerConn) RemoteAddr() net.Addr { + return c.raddr +} + +func (c *udpServerConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *udpServerConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *udpServerConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +type udpClientConn struct { + *net.UDPConn +} + +func (c *udpClientConn) WriteTo(b []byte, addr net.Addr) (int, error) { + return c.UDPConn.Write(b) +} + +func (c *udpClientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + n, err = c.UDPConn.Read(b) + addr = c.RemoteAddr() + return +} diff --git a/ws.go b/ws.go new file mode 100644 index 0000000..7cc1684 --- /dev/null +++ b/ws.go @@ -0,0 +1,815 @@ +package gost + +import ( + "crypto/rand" + "crypto/sha1" + "crypto/tls" + "encoding/base64" + "io" + "net" + "net/http" + "net/http/httputil" + "sync" + "time" + + "net/url" + + "github.com/go-log/log" + "github.com/gorilla/websocket" + smux "github.com/xtaci/smux" +) + +const ( + defaultWSPath = "/ws" +) + +// WSOptions describes the options for websocket. +type WSOptions struct { + ReadBufferSize int + WriteBufferSize int + HandshakeTimeout time.Duration + EnableCompression bool + UserAgent string + Path string + HeaderConfig map[string]string +} + +type wsTransporter struct { + tcpTransporter + options *WSOptions +} + +// WSTransporter creates a Transporter that is used by websocket proxy client. +func WSTransporter(opts *WSOptions) Transporter { + return &wsTransporter{ + options: opts, + } +} + +func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + if wsOptions == nil { + wsOptions = &WSOptions{} + } + + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "ws", Host: opts.Host, Path: path} + return websocketClientConn(url.String(), conn, nil, wsOptions) +} + +type mwsTransporter struct { + tcpTransporter + options *WSOptions + sessions map[string]*muxSession + sessionMutex sync.Mutex +} + +// MWSTransporter creates a Transporter that is used by multiplex-websocket proxy client. +func MWSTransporter(opts *WSOptions) Transporter { + return &mwsTransporter{ + options: opts, + sessions: make(map[string]*muxSession), + } +} + +func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if session != nil && session.IsClosed() { + delete(tr.sessions, addr) + ok = false + } + if !ok { + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &muxSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, opts) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + return cc, nil +} + +func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { + if opts == nil { + opts = &HandshakeOptions{} + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + if wsOptions == nil { + wsOptions = &WSOptions{} + } + + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "ws", Host: opts.Host, Path: path} + conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) + if err != nil { + return nil, err + } + // stream multiplex + smuxConfig := smux.DefaultConfig() + session, err := smux.Client(conn, smuxConfig) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} + +func (tr *mwsTransporter) Multiplex() bool { + return true +} + +type wssTransporter struct { + tcpTransporter + options *WSOptions +} + +// WSSTransporter creates a Transporter that is used by websocket secure proxy client. +func WSSTransporter(opts *WSOptions) Transporter { + return &wssTransporter{ + options: opts, + } +} + +func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + if wsOptions == nil { + wsOptions = &WSOptions{} + } + + if opts.TLSConfig == nil { + opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "wss", Host: opts.Host, Path: path} + return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) +} + +type mwssTransporter struct { + tcpTransporter + options *WSOptions + sessions map[string]*muxSession + sessionMutex sync.Mutex +} + +// MWSSTransporter creates a Transporter that is used by multiplex-websocket secure proxy client. +func MWSSTransporter(opts *WSOptions) Transporter { + return &mwssTransporter{ + options: opts, + sessions: make(map[string]*muxSession), + } +} + +func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { + opts := &DialOptions{} + for _, option := range options { + option(opts) + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + session, ok := tr.sessions[addr] + if session != nil && session.IsClosed() { + delete(tr.sessions, addr) + ok = false + } + if !ok { + timeout := opts.Timeout + if timeout <= 0 { + timeout = DialTimeout + } + + if opts.Chain == nil { + conn, err = net.DialTimeout("tcp", addr, timeout) + } else { + conn, err = opts.Chain.Dial(addr) + } + if err != nil { + return + } + session = &muxSession{conn: conn} + tr.sessions[addr] = session + } + return session.conn, nil +} + +func (tr *mwssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { + opts := &HandshakeOptions{} + for _, option := range options { + option(opts) + } + + timeout := opts.Timeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + tr.sessionMutex.Lock() + defer tr.sessionMutex.Unlock() + + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) + + session, ok := tr.sessions[opts.Addr] + if !ok || session.session == nil { + s, err := tr.initSession(opts.Addr, conn, opts) + if err != nil { + conn.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + session = s + tr.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(tr.sessions, opts.Addr) + return nil, err + } + return cc, nil +} + +func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { + if opts == nil { + opts = &HandshakeOptions{} + } + wsOptions := tr.options + if opts.WSOptions != nil { + wsOptions = opts.WSOptions + } + if wsOptions == nil { + wsOptions = &WSOptions{} + } + + tlsConfig := opts.TLSConfig + if tlsConfig == nil { + tlsConfig = &tls.Config{InsecureSkipVerify: true} + } + path := wsOptions.Path + if path == "" { + path = defaultWSPath + } + url := url.URL{Scheme: "wss", Host: opts.Host, Path: path} + conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) + if err != nil { + return nil, err + } + // stream multiplex + smuxConfig := smux.DefaultConfig() + session, err := smux.Client(conn, smuxConfig) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} + +func (tr *mwssTransporter) Multiplex() bool { + return true +} + +type wsListener struct { + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + connChan chan net.Conn + errChan chan error +} + +// WSListener creates a Listener for websocket proxy server. +func WSListener(addr string, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &wsListener{ + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + + path := options.Path + if path == "" { + path = defaultWSPath + } + mux := http.NewServeMux() + mux.Handle(path, http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 30 * time.Second, + } + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + l.addr = ln.Addr() + + go func() { + err := l.srv.Serve(tcpKeepAliveListener{ln}) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +} + +func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { + log.Logf("[ws] %s -> %s", r.RemoteAddr, l.addr) + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Log(string(dump)) + } + conn, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Logf("[ws] %s - %s : %s", r.RemoteAddr, l.addr, err) + return + } + select { + case l.connChan <- websocketServerConn(conn): + default: + conn.Close() + log.Logf("[ws] %s - %s: connection queue is full", r.RemoteAddr, l.addr) + } +} + +func (l *wsListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + } + return +} + +func (l *wsListener) Close() error { + return l.srv.Close() +} + +func (l *wsListener) Addr() net.Addr { + return l.addr +} + +type mwsListener struct { + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + connChan chan net.Conn + errChan chan error +} + +// MWSListener creates a Listener for multiplex-websocket proxy server. +func MWSListener(addr string, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &mwsListener{ + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + } + + path := options.Path + if path == "" { + path = defaultWSPath + } + + mux := http.NewServeMux() + mux.Handle(path, http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 30 * time.Second, + } + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + l.addr = ln.Addr() + + go func() { + err := l.srv.Serve(tcpKeepAliveListener{ln}) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +} + +func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) { + log.Logf("[mws] %s -> %s", r.RemoteAddr, l.addr) + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Log(string(dump)) + } + conn, err := l.upgrader.Upgrade(w, r, nil) + if err != nil { + log.Logf("[mws] %s - %s : %s", r.RemoteAddr, l.addr, err) + return + } + + l.mux(websocketServerConn(conn)) +} + +func (l *mwsListener) mux(conn net.Conn) { + smuxConfig := smux.DefaultConfig() + mux, err := smux.Server(conn, smuxConfig) + if err != nil { + log.Logf("[mws] %s - %s : %s", conn.RemoteAddr(), l.Addr(), err) + return + } + defer mux.Close() + + log.Logf("[mws] %s <-> %s", conn.RemoteAddr(), l.Addr()) + defer log.Logf("[mws] %s >-< %s", conn.RemoteAddr(), l.Addr()) + + for { + stream, err := mux.AcceptStream() + if err != nil { + log.Log("[mws] accept stream:", err) + return + } + + cc := &muxStreamConn{Conn: conn, stream: stream} + select { + case l.connChan <- cc: + default: + cc.Close() + log.Logf("[mws] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + } +} + +func (l *mwsListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + } + return +} + +func (l *mwsListener) Close() error { + return l.srv.Close() +} + +func (l *mwsListener) Addr() net.Addr { + return l.addr +} + +type wssListener struct { + *wsListener +} + +// WSSListener creates a Listener for websocket secure proxy server. +func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &wssListener{ + wsListener: &wsListener{ + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + }, + } + + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + + path := options.Path + if path == "" { + path = defaultWSPath + } + + mux := http.NewServeMux() + mux.Handle(path, http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: addr, + TLSConfig: tlsConfig, + Handler: mux, + ReadHeaderTimeout: 30 * time.Second, + } + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + l.addr = ln.Addr() + + go func() { + err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +} + +type mwssListener struct { + *mwsListener +} + +// MWSSListener creates a Listener for multiplex-websocket secure proxy server. +func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + if options == nil { + options = &WSOptions{} + } + l := &mwssListener{ + mwsListener: &mwsListener{ + upgrader: &websocket.Upgrader{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: options.EnableCompression, + }, + connChan: make(chan net.Conn, 1024), + errChan: make(chan error, 1), + }, + } + + if tlsConfig == nil { + tlsConfig = DefaultTLSConfig + } + + path := options.Path + if path == "" { + path = defaultWSPath + } + + mux := http.NewServeMux() + mux.Handle(path, http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: addr, + TLSConfig: tlsConfig, + Handler: mux, + ReadHeaderTimeout: 30 * time.Second, + } + + ln, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + return nil, err + } + l.addr = ln.Addr() + + go func() { + err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + select { + case err := <-l.errChan: + return nil, err + default: + } + + return l, nil +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func computeAcceptKey(challengeKey string) string { + h := sha1.New() + h.Write([]byte(challengeKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +func generateChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +// TODO: due to the concurrency control in the websocket.Conn, +// a data race may be met when using with multiplexing. +// See: https://godoc.org/gopkg.in/gorilla/websocket.v1#hdr-Concurrency +type websocketConn struct { + conn *websocket.Conn + rb []byte +} + +func websocketClientConn(url string, conn net.Conn, tlsConfig *tls.Config, options *WSOptions) (net.Conn, error) { + if options == nil { + options = &WSOptions{} + } + + timeout := options.HandshakeTimeout + if timeout <= 0 { + timeout = HandshakeTimeout + } + + dialer := websocket.Dialer{ + ReadBufferSize: options.ReadBufferSize, + WriteBufferSize: options.WriteBufferSize, + TLSClientConfig: tlsConfig, + HandshakeTimeout: timeout, + EnableCompression: options.EnableCompression, + NetDial: func(net, addr string) (net.Conn, error) { + return conn, nil + }, + } + header := http.Header{} + header.Set("User-Agent", DefaultUserAgent) + if options.UserAgent != "" { + header.Set("User-Agent", options.UserAgent) + } + //Process Header + for k, v := range options.HeaderConfig { + if len(k) > 2 && k[0:2] == "--" { + header.Del(k[2:]) + continue + } + header.Set(k, v) + } + c, resp, err := dialer.Dial(url, header) + if err != nil { + return nil, err + } + resp.Body.Close() + return &websocketConn{conn: c}, nil +} + +func websocketServerConn(conn *websocket.Conn) net.Conn { + // conn.EnableWriteCompression(true) + return &websocketConn{ + conn: conn, + } +} + +func (c *websocketConn) Read(b []byte) (n int, err error) { + if len(c.rb) == 0 { + _, c.rb, err = c.conn.ReadMessage() + } + n = copy(b, c.rb) + c.rb = c.rb[n:] + return +} + +func (c *websocketConn) Write(b []byte) (n int, err error) { + err = c.conn.WriteMessage(websocket.BinaryMessage, b) + n = len(b) + return +} + +func (c *websocketConn) Close() error { + return c.conn.Close() +} + +func (c *websocketConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *websocketConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *websocketConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + return c.SetWriteDeadline(t) +} +func (c *websocketConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/ws_test.go b/ws_test.go new file mode 100644 index 0000000..11899c9 --- /dev/null +++ b/ws_test.go @@ -0,0 +1,808 @@ +package gost + +import ( + "crypto/rand" + "fmt" + "net/http/httptest" + "net/url" + "testing" +) + +func httpOverWSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := WSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverWSRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverWS(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := WSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverWSParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := WSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverWSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := WSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverWSRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverWSRoundtrip(targetURL string, data []byte) error { + ln, err := WSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverWSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverWSRoundtrip(targetURL string, data []byte) error { + ln, err := WSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverWS(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverWSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverWSRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := WSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverWSRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverWSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := WSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverWSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func wsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := WSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: WSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestWSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := wsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func httpOverMWSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverMWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverMWSRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverMWS(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := MWSListener("", nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverMWSParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := MWSListener("", nil) + if err != nil { + b.Error(err) + } + + b.Log(ln.Addr()) + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverMWSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverMWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverMWSRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverMWSRoundtrip(targetURL string, data []byte) error { + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverMWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverMWSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverMWSRoundtrip(targetURL string, data []byte) error { + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverMWS(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverMWSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverMWSRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverMWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverMWSRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverMWSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverMWS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverMWSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func mwsForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := MWSListener("", nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: MWSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestMWSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := mwsForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} diff --git a/wss_test.go b/wss_test.go new file mode 100644 index 0000000..a213480 --- /dev/null +++ b/wss_test.go @@ -0,0 +1,809 @@ +package gost + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "net/http/httptest" + "net/url" + "testing" +) + +func httpOverWSSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := WSSListener("", tlsConfig, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverWSSRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverWSS(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := WSSListener("", nil, nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverWSSParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := WSSListener("", nil, nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverWSSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverWSSRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverWSSRoundtrip(targetURL string, data []byte) error { + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverWSSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverWSSRoundtrip(targetURL string, data []byte) error { + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverWSS(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverWSSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverWSSRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverWSSRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverWSSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverWSSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func wssForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := WSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: WSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestWSSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := wssForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +} + +func httpOverMWSSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: HTTPConnector(clientInfo), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestHTTPOverMWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range httpProxyTests { + err := httpOverMWSSRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) + if err == nil { + if tc.errStr != "" { + t.Errorf("#%d should failed with error %s", i, tc.errStr) + } + } else { + if tc.errStr == "" { + t.Errorf("#%d got error %v", i, err) + } + if err.Error() != tc.errStr { + t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) + } + } + } +} + +func BenchmarkHTTPOverMWSS(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := MWSSListener("", nil, nil) + if err != nil { + b.Error(err) + } + + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + for i := 0; i < b.N; i++ { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } +} + +func BenchmarkHTTPOverMWSSParallel(b *testing.B) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + ln, err := MWSSListener("", nil, nil) + if err != nil { + b.Error(err) + } + + b.Log(ln.Addr()) + client := &Client{ + Connector: HTTPConnector(url.UserPassword("admin", "123456")), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: HTTPHandler( + UsersHandlerOption(url.UserPassword("admin", "123456")), + ), + } + go server.Run() + defer server.Close() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { + b.Error(err) + } + } + }) +} + +func socks5OverMWSSRoundtrip(targetURL string, data []byte, + clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { + + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS5Connector(clientInfo), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS5Handler( + UsersHandlerOption(serverInfo...), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS5OverMWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range socks5ProxyTests { + err := socks5OverMWSSRoundtrip(httpSrv.URL, sendData, + tc.cliUser, + tc.srvUsers, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func socks4OverMWSSRoundtrip(targetURL string, data []byte) error { + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4Connector(), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4OverMWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4OverMWSSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func socks4aOverMWSSRoundtrip(targetURL string, data []byte) error { + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: SOCKS4AConnector(), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SOCKS4Handler(), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSOCKS4AOverMWSS(t *testing.T) { + + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := socks4aOverMWSSRoundtrip(httpSrv.URL, sendData) + // t.Logf("#%d %v", i, err) + if err != nil { + t.Errorf("got error: %v", err) + } +} + +func ssOverMWSSRoundtrip(targetURL string, data []byte, + clientInfo, serverInfo *url.Userinfo) error { + + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + client := &Client{ + Connector: ShadowConnector(clientInfo), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: ShadowHandler( + UsersHandlerOption(serverInfo), + ), + } + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestSSOverMWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + for i, tc := range ssProxyTests { + err := ssOverMWSSRoundtrip(httpSrv.URL, sendData, + tc.clientCipher, + tc.serverCipher, + ) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + } +} + +func sniOverMWSSRoundtrip(targetURL string, data []byte, host string) error { + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: SNIConnector(host), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: SNIHandler(HostHandlerOption(u.Host)), + } + + go server.Run() + defer server.Close() + + return sniRoundtrip(client, server, targetURL, data) +} + +func TestSNIOverMWSS(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + httpsSrv := httptest.NewTLSServer(httpTestHandler) + defer httpsSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + var sniProxyTests = []struct { + targetURL string + host string + pass bool + }{ + {httpSrv.URL, "", true}, + {httpSrv.URL, "example.com", true}, + {httpsSrv.URL, "", true}, + {httpsSrv.URL, "example.com", true}, + } + + for i, tc := range sniProxyTests { + tc := tc + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + err := sniOverMWSSRoundtrip(tc.targetURL, sendData, tc.host) + if err == nil { + if !tc.pass { + t.Errorf("#%d should failed", i) + } + } else { + // t.Logf("#%d %v", i, err) + if tc.pass { + t.Errorf("#%d got error: %v", i, err) + } + } + }) + } +} + +func mwssForwardTunnelRoundtrip(targetURL string, data []byte) error { + ln, err := MWSSListener("", nil, nil) + if err != nil { + return err + } + + u, err := url.Parse(targetURL) + if err != nil { + return err + } + + client := &Client{ + Connector: ForwardConnector(), + Transporter: MWSSTransporter(nil), + } + + server := &Server{ + Listener: ln, + Handler: TCPDirectForwardHandler(u.Host), + } + server.Handler.Init() + + go server.Run() + defer server.Close() + + return proxyRoundtrip(client, server, targetURL, data) +} + +func TestMWSSForwardTunnel(t *testing.T) { + httpSrv := httptest.NewServer(httpTestHandler) + defer httpSrv.Close() + + sendData := make([]byte, 128) + rand.Read(sendData) + + err := mwssForwardTunnelRoundtrip(httpSrv.URL, sendData) + if err != nil { + t.Error(err) + } +}