update config parsing
This commit is contained in:
@ -1,5 +1,9 @@
|
||||
package chain
|
||||
|
||||
type Chainer interface {
|
||||
Route(network, address string) *Route
|
||||
}
|
||||
|
||||
type Chain struct {
|
||||
groups []*NodeGroup
|
||||
}
|
||||
@ -8,16 +12,12 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) {
|
||||
c.groups = append(c.groups, group)
|
||||
}
|
||||
|
||||
func (c *Chain) GetRoute() (r *route) {
|
||||
return c.GetRouteFor("tcp", "")
|
||||
}
|
||||
|
||||
func (c *Chain) GetRouteFor(network, address string) (r *route) {
|
||||
func (c *Chain) Route(network, address string) (r *Route) {
|
||||
if c == nil || len(c.groups) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
r = &route{}
|
||||
r = &Route{}
|
||||
for _, group := range c.groups {
|
||||
node := group.Next()
|
||||
if node == nil {
|
||||
@ -32,14 +32,10 @@ func (c *Chain) GetRouteFor(network, address string) (r *route) {
|
||||
WithRoute(r)
|
||||
node = node.Copy()
|
||||
node.Transport = tr
|
||||
r = &route{}
|
||||
r = &Route{}
|
||||
}
|
||||
|
||||
r.AddNode(node)
|
||||
r.addNode(node)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (c *Chain) IsEmpty() bool {
|
||||
return c == nil || len(c.groups) == 0
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"github.com/go-gost/gost/pkg/resolver"
|
||||
)
|
||||
|
||||
func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
|
||||
func resolve(ctx context.Context, network, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
|
||||
if addr == "" {
|
||||
return addr, nil
|
||||
}
|
||||
@ -24,14 +24,14 @@ func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts
|
||||
}
|
||||
|
||||
if hosts != nil {
|
||||
if ips, _ := hosts.Lookup("ip", host); len(ips) > 0 {
|
||||
if ips, _ := hosts.Lookup(network, host); len(ips) > 0 {
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
return net.JoinHostPort(ips[0].String(), port), nil
|
||||
}
|
||||
}
|
||||
|
||||
if resolver != nil {
|
||||
ips, err := resolver.Resolve(ctx, host)
|
||||
ips, err := resolver.Resolve(ctx, network, host)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
@ -15,17 +15,17 @@ var (
|
||||
ErrEmptyRoute = errors.New("empty route")
|
||||
)
|
||||
|
||||
type route struct {
|
||||
type Route struct {
|
||||
nodes []*Node
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (r *route) AddNode(node *Node) {
|
||||
func (r *Route) addNode(node *Node) {
|
||||
r.nodes = append(r.nodes, node)
|
||||
}
|
||||
|
||||
func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if r.IsEmpty() {
|
||||
func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
if r.Len() == 0 {
|
||||
return r.dialDirect(ctx, network, address)
|
||||
}
|
||||
|
||||
@ -34,7 +34,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc, err := r.Last().Transport.Connect(ctx, conn, network, address)
|
||||
cc, err := r.GetNode(r.Len()-1).Transport.Connect(ctx, conn, network, address)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
@ -42,7 +42,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (r *route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
switch network {
|
||||
case "udp", "udp4", "udp6":
|
||||
if address == "" {
|
||||
@ -55,8 +55,8 @@ func (r *route) dialDirect(ctx context.Context, network, address string) (net.Co
|
||||
return d.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
func (r *route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
if r.IsEmpty() {
|
||||
func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
if r.Len() == 0 {
|
||||
return r.bindLocal(ctx, network, address, opts...)
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ln, err := r.Last().Transport.Bind(ctx, conn, network, address, opts...)
|
||||
ln, err := r.GetNode(r.Len()-1).Transport.Bind(ctx, conn, network, address, opts...)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
@ -74,14 +74,15 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
func (r *route) connect(ctx context.Context) (conn net.Conn, err error) {
|
||||
if r.IsEmpty() {
|
||||
func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
|
||||
if r.Len() == 0 {
|
||||
return nil, ErrEmptyRoute
|
||||
}
|
||||
|
||||
network := "ip"
|
||||
node := r.nodes[0]
|
||||
|
||||
addr, err := resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger)
|
||||
addr, err := resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger)
|
||||
if err != nil {
|
||||
node.Marker.Mark()
|
||||
return
|
||||
@ -102,7 +103,7 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) {
|
||||
|
||||
preNode := node
|
||||
for _, node := range r.nodes[1:] {
|
||||
addr, err = resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger)
|
||||
addr, err = resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
node.Marker.Mark()
|
||||
@ -130,18 +131,21 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *route) IsEmpty() bool {
|
||||
return r == nil || len(r.nodes) == 0
|
||||
func (r *Route) Len() int {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return len(r.nodes)
|
||||
}
|
||||
|
||||
func (r *route) Last() *Node {
|
||||
if r.IsEmpty() {
|
||||
func (r *Route) GetNode(index int) *Node {
|
||||
if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
|
||||
return nil
|
||||
}
|
||||
return r.nodes[len(r.nodes)-1]
|
||||
return r.nodes[index]
|
||||
}
|
||||
|
||||
func (r *route) Path() (path []*Node) {
|
||||
func (r *Route) Path() (path []*Node) {
|
||||
if r == nil || len(r.nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -155,7 +159,7 @@ func (r *route) Path() (path []*Node) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
|
||||
options := connector.BindOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
|
||||
type Router struct {
|
||||
Retries int
|
||||
Chain *Chain
|
||||
Chain Chainer
|
||||
Hosts hosts.HostMapper
|
||||
Resolver resolver.Resolver
|
||||
Logger logger.Logger
|
||||
@ -41,7 +41,10 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
||||
r.Logger.Debugf("dial %s/%s", address, network)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
route := r.Chain.GetRouteFor(network, address)
|
||||
var route *Route
|
||||
if r.Chain != nil {
|
||||
route = r.Chain.Route(network, address)
|
||||
}
|
||||
|
||||
if r.Logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
buf := bytes.Buffer{}
|
||||
@ -52,7 +55,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
||||
r.Logger.Debugf("route(retry=%d) %s", i, buf.String())
|
||||
}
|
||||
|
||||
address, err = resolve(ctx, address, r.Resolver, r.Hosts, r.Logger)
|
||||
address, err = resolve(ctx, "ip", address, r.Resolver, r.Hosts, r.Logger)
|
||||
if err != nil {
|
||||
r.Logger.Error(err)
|
||||
break
|
||||
@ -80,7 +83,10 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn
|
||||
r.Logger.Debugf("bind on %s/%s", address, network)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
route := r.Chain.GetRouteFor(network, address)
|
||||
var route *Route
|
||||
if r.Chain != nil {
|
||||
route = r.Chain.Route(network, address)
|
||||
}
|
||||
|
||||
if r.Logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
buf := bytes.Buffer{}
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
|
||||
type Transport struct {
|
||||
addr string
|
||||
route *route
|
||||
route *Route
|
||||
dialer dialer.Dialer
|
||||
connector connector.Connector
|
||||
}
|
||||
@ -39,7 +39,7 @@ func (tr *Transport) dialOptions() []dialer.DialOption {
|
||||
opts := []dialer.DialOption{
|
||||
dialer.HostDialOption(tr.addr),
|
||||
}
|
||||
if !tr.route.IsEmpty() {
|
||||
if tr.route.Len() > 0 {
|
||||
opts = append(opts,
|
||||
dialer.DialFuncDialOption(
|
||||
func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
@ -84,7 +84,7 @@ func (tr *Transport) Multiplex() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (tr *Transport) WithRoute(r *route) *Transport {
|
||||
func (tr *Transport) WithRoute(r *Route) *Transport {
|
||||
tr.route = r
|
||||
return tr
|
||||
}
|
||||
|
Reference in New Issue
Block a user