337 lines
8.2 KiB
Go
337 lines
8.2 KiB
Go
|
|
package starssh
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bufio"
|
||
|
|
"context"
|
||
|
|
"encoding/base64"
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"net/http"
|
||
|
|
"strconv"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// bufferedConn is a net.Conn whose reads are served through a bufio.Reader
// layered over the same connection. dialHTTPConnect returns one when its
// response reader has already buffered bytes past the proxy's reply, so those
// bytes reach the caller before anything still unread on the socket. All other
// net.Conn methods (Write, Close, deadlines, addresses) go to the embedded Conn.
type bufferedConn struct {
	net.Conn
	reader *bufio.Reader
}

// Read serves p from the buffered reader (which falls through to the
// underlying connection once its buffer is empty). A nil receiver, or one
// without a reader, reports (0, io.EOF).
func (c *bufferedConn) Read(p []byte) (int, error) {
	if c != nil && c.reader != nil {
		return c.reader.Read(p)
	}
	return 0, io.EOF
}
|
||
|
|
|
||
|
|
func resolveDialContext(info LoginInput) DialContextFunc {
|
||
|
|
if info.DialContext != nil {
|
||
|
|
return info.DialContext
|
||
|
|
}
|
||
|
|
|
||
|
|
dialer := &net.Dialer{
|
||
|
|
Timeout: info.Timeout,
|
||
|
|
}
|
||
|
|
return dialer.DialContext
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) {
|
||
|
|
targetAddr := joinHostPort(info.Addr, info.Port)
|
||
|
|
if info.Jump != nil {
|
||
|
|
return dialViaJump(ctx, info, targetAddr)
|
||
|
|
}
|
||
|
|
|
||
|
|
dialContext := resolveDialContext(info)
|
||
|
|
proxyConfig := normalizeProxyConfig(info.Proxy, info.Timeout)
|
||
|
|
if proxyConfig != nil {
|
||
|
|
return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr)
|
||
|
|
}
|
||
|
|
|
||
|
|
conn, err := dialContext(ctx, "tcp", targetAddr)
|
||
|
|
return conn, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) {
|
||
|
|
if info.Jump == nil {
|
||
|
|
return nil, nil, errors.New("jump login info is nil")
|
||
|
|
}
|
||
|
|
|
||
|
|
jumpClient, err := loginWithContext(ctx, *info.Jump)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close)
|
||
|
|
if err != nil {
|
||
|
|
_ = jumpClient.Close()
|
||
|
|
return nil, nil, err
|
||
|
|
}
|
||
|
|
return conn, jumpClient, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) {
|
||
|
|
if dialContext == nil {
|
||
|
|
return nil, nil, errors.New("dial context is nil")
|
||
|
|
}
|
||
|
|
if strings.TrimSpace(proxy.Addr) == "" {
|
||
|
|
return nil, nil, errors.New("proxy address is empty")
|
||
|
|
}
|
||
|
|
|
||
|
|
switch proxy.Type {
|
||
|
|
case ProxyTypeSOCKS5:
|
||
|
|
conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr)
|
||
|
|
return conn, nil, err
|
||
|
|
case ProxyTypeHTTPConnect:
|
||
|
|
conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr)
|
||
|
|
return conn, nil, err
|
||
|
|
default:
|
||
|
|
return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
|
||
|
|
conn, err := dialContext(ctx, "tcp", proxy.Addr)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
|
||
|
|
defer restoreDeadline()
|
||
|
|
|
||
|
|
request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr)
|
||
|
|
if proxy.Username != "" || proxy.Password != "" {
|
||
|
|
token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password))
|
||
|
|
request += "Proxy-Authorization: Basic " + token + "\r\n"
|
||
|
|
}
|
||
|
|
request += "\r\n"
|
||
|
|
|
||
|
|
if _, err := io.WriteString(conn, request); err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
reader := bufio.NewReader(conn)
|
||
|
|
response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
||
|
|
if err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
defer response.Body.Close()
|
||
|
|
|
||
|
|
if response.StatusCode < 200 || response.StatusCode >= 300 {
|
||
|
|
_, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024))
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status)
|
||
|
|
}
|
||
|
|
|
||
|
|
if reader.Buffered() == 0 {
|
||
|
|
return conn, nil
|
||
|
|
}
|
||
|
|
return &bufferedConn{Conn: conn, reader: reader}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
|
||
|
|
conn, err := dialContext(ctx, "tcp", proxy.Addr)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
|
||
|
|
defer restoreDeadline()
|
||
|
|
|
||
|
|
methods := []byte{0x00}
|
||
|
|
useAuth := proxy.Username != "" || proxy.Password != ""
|
||
|
|
if useAuth {
|
||
|
|
methods = append(methods, 0x02)
|
||
|
|
}
|
||
|
|
|
||
|
|
hello := append([]byte{0x05, byte(len(methods))}, methods...)
|
||
|
|
if _, err := conn.Write(hello); err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
response := make([]byte, 2)
|
||
|
|
if _, err := io.ReadFull(conn, response); err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if response[0] != 0x05 {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, fmt.Errorf("invalid socks5 version %d", response[0])
|
||
|
|
}
|
||
|
|
if response[1] == 0xFF {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, errors.New("socks5 proxy has no acceptable auth method")
|
||
|
|
}
|
||
|
|
|
||
|
|
if response[1] == 0x02 {
|
||
|
|
if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := writeSOCKS5Connect(conn, targetAddr); err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := readSOCKS5ConnectResponse(conn); err != nil {
|
||
|
|
_ = conn.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return conn, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// writeSOCKS5UserPassAuth performs the RFC 1929 username/password
// sub-negotiation on conn after the SOCKS5 proxy has selected method 0x02.
//
// It writes VER=0x01, ULEN, UNAME, PLEN, PASSWD and reads the two-byte reply
// (VER, STATUS). It returns an error in these cases:
//   - either credential is longer than the 255-byte field limit (nothing is
//     written);
//   - any I/O fails;
//   - the reply version is not 0x01;
//   - STATUS is not 0x00 (success).
func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error {
	if len(username) > 255 || len(password) > 255 {
		return errors.New("socks5 username/password too long")
	}

	request := make([]byte, 0, 3+len(username)+len(password))
	request = append(request, 0x01, byte(len(username)))
	request = append(request, username...)
	request = append(request, byte(len(password)))
	request = append(request, password...)

	if _, err := conn.Write(request); err != nil {
		return err
	}

	response := make([]byte, 2)
	if _, err := io.ReadFull(conn, response); err != nil {
		return err
	}
	// The sub-negotiation has its own version number (0x01), distinct from
	// the SOCKS protocol version; anything else means a confused peer.
	if response[0] != 0x01 {
		return fmt.Errorf("invalid socks5 username/password auth version %d", response[0])
	}
	if response[1] != 0x00 {
		return errors.New("socks5 username/password authentication failed")
	}
	return nil
}
|
||
|
|
|
||
|
|
// writeSOCKS5Connect sends a SOCKS5 CONNECT request (RFC 1928 §4) for
// targetAddr, given as "host:port", on conn.
//
// The host is encoded as follows:
//   - IPv4 literals (including IPv4-mapped IPv6) use ATYP 0x01;
//   - other IP literals use ATYP 0x04 (16 bytes);
//   - anything else is sent as a domain name (ATYP 0x03) for the proxy to
//     resolve.
//
// It returns an error without writing anything if targetAddr cannot be
// split, the host is empty, a domain name exceeds 255 bytes, or the port is
// not in 1..65535.
func writeSOCKS5Connect(conn net.Conn, targetAddr string) error {
	host, portString, err := net.SplitHostPort(targetAddr)
	if err != nil {
		return err
	}
	if host == "" {
		// A zero-length DST.ADDR is not a valid SOCKS5 address.
		return errors.New("socks5 target host is empty")
	}

	port, err := strconv.Atoi(portString)
	if err != nil {
		return err
	}
	// Port 0 cannot be connected to, so it is rejected with out-of-range values.
	if port < 1 || port > 65535 {
		return fmt.Errorf("invalid port %d", port)
	}

	request := []byte{0x05, 0x01, 0x00} // VER=5, CMD=CONNECT, RSV=0
	ip := net.ParseIP(host)
	switch {
	case ip == nil:
		if len(host) > 255 {
			return errors.New("socks5 target host too long")
		}
		request = append(request, 0x03, byte(len(host)))
		request = append(request, host...)
	case ip.To4() != nil:
		request = append(request, 0x01)
		request = append(request, ip.To4()...)
	default:
		request = append(request, 0x04)
		request = append(request, ip.To16()...)
	}

	request = append(request, byte(port>>8), byte(port)) // DST.PORT, network byte order
	_, err = conn.Write(request)
	return err
}
|
||
|
|
|
||
|
|
// readSOCKS5ConnectResponse consumes the proxy's reply to a CONNECT request
// (RFC 1928 §6). The reply layout is VER, REP, RSV, ATYP, BND.ADDR,
// BND.PORT.
//
// It returns nil only when VER is 5, REP is 0 (succeeded), and the whole
// reply has been read, including the variable-length bound address and the
// port. Afterwards conn is positioned at the first byte of the tunnelled
// stream. The bound address itself is discarded.
func readSOCKS5ConnectResponse(conn net.Conn) error {
	var head [4]byte
	if _, err := io.ReadFull(conn, head[:]); err != nil {
		return err
	}
	version, reply, addrType := head[0], head[1], head[3]
	if version != 0x05 {
		return fmt.Errorf("invalid socks5 response version %d", version)
	}
	if reply != 0x00 {
		return fmt.Errorf("socks5 connect failed with code %d", reply)
	}

	// scratch fits the longest possible BND.ADDR (a 255-byte domain name).
	scratch := make([]byte, 255)
	var addrLen int
	switch addrType {
	case 0x01: // IPv4
		addrLen = 4
	case 0x03: // domain name, preceded by a one-byte length
		if _, err := io.ReadFull(conn, scratch[:1]); err != nil {
			return err
		}
		addrLen = int(scratch[0])
	case 0x04: // IPv6
		addrLen = 16
	default:
		return fmt.Errorf("unsupported socks5 bind address type %d", addrType)
	}
	if _, err := io.ReadFull(conn, scratch[:addrLen]); err != nil {
		return err
	}

	_, err := io.ReadFull(conn, scratch[:2]) // BND.PORT
	return err
}
|
||
|
|
|
||
|
|
// applyConnDeadline bounds the blocking I/O performed on conn while a proxy
// handshake is in progress. It returns a function that removes the bound.
//
// The deadline set on conn is the earlier of ctx's deadline (if any) and
// now+timeout (if timeout > 0). When ctx can be cancelled, a watcher
// goroutine also moves the conn deadline into the past as soon as ctx is
// done. A Read or Write blocked on an unresponsive proxy then returns
// promptly with a timeout error instead of ignoring the cancellation.
//
// The returned function stops the watcher, waits for it to exit (so the
// watcher cannot re-arm the deadline afterwards), and clears the deadline on
// conn. It is meant to be deferred once; a second call is a no-op. A nil
// conn returns a no-op. So does the case where ctx has neither a deadline
// nor a Done channel and timeout <= 0.
//
// NOTE: ctx is the second parameter (not the first, as is idiomatic) to keep
// existing call sites unchanged.
func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() {
	if conn == nil {
		return func() {}
	}

	var (
		deadline time.Time
		hasValue bool
		done     <-chan struct{}
	)
	if ctx != nil {
		if ctxDeadline, ok := ctx.Deadline(); ok {
			deadline = ctxDeadline
			hasValue = true
		}
		// Done is nil for contexts that can never be cancelled
		// (e.g. context.Background); no watcher is needed then.
		done = ctx.Done()
	}
	if timeout > 0 {
		timeoutDeadline := time.Now().Add(timeout)
		if !hasValue || timeoutDeadline.Before(deadline) {
			deadline = timeoutDeadline
			hasValue = true
		}
	}
	if !hasValue && done == nil {
		return func() {}
	}

	if hasValue {
		// Best effort: a conn that rejects deadlines will surface errors on
		// its next I/O anyway.
		_ = conn.SetDeadline(deadline)
	}

	stop := make(chan struct{})
	exited := make(chan struct{})
	if done != nil {
		go func() {
			defer close(exited)
			select {
			case <-done:
				// A deadline in the past makes pending and future I/O fail
				// immediately, unblocking the handshake.
				_ = conn.SetDeadline(time.Unix(1, 0))
			case <-stop:
			}
		}()
	} else {
		close(exited)
	}

	return func() {
		select {
		case <-stop:
			return // already restored
		default:
			close(stop)
		}
		<-exited
		_ = conn.SetDeadline(time.Time{})
	}
}
|
||
|
|
|
||
|
|
func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig {
|
||
|
|
if proxy == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
normalized := *proxy
|
||
|
|
normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type))))
|
||
|
|
normalized.Addr = strings.TrimSpace(normalized.Addr)
|
||
|
|
if normalized.Timeout <= 0 {
|
||
|
|
normalized.Timeout = defaultTimeout
|
||
|
|
}
|
||
|
|
return &normalized
|
||
|
|
}
|
||
|
|
|
||
|
|
// joinHostPort builds a "host:port" dial address from a configured host and
// port. Surrounding whitespace in host is trimmed. IPv6 literals are
// bracketed by net.JoinHostPort. A host that is already bracketed (e.g.
// "[::1]", as often written in configs) has its enclosing brackets removed
// first, so it is not double-bracketed into an undialable "[[::1]]:port".
func joinHostPort(host string, port int) string {
	host = strings.TrimSpace(host)
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, strconv.Itoa(port))
}
|