package starssh import ( "bufio" "context" "encoding/base64" "errors" "fmt" "io" "net" "net/http" "strconv" "strings" "time" ) type bufferedConn struct { net.Conn reader *bufio.Reader } func (c *bufferedConn) Read(p []byte) (int, error) { if c == nil || c.reader == nil { return 0, io.EOF } return c.reader.Read(p) } func resolveDialContext(info LoginInput) DialContextFunc { if info.DialContext != nil { return info.DialContext } dialer := &net.Dialer{ Timeout: effectiveDialTimeout(info), } return dialer.DialContext } func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) { targetAddr := joinHostPort(info.Addr, info.Port) if info.Jump != nil { return dialViaJump(ctx, info, targetAddr) } dialContext := resolveDialContext(info) proxyConfig := normalizeProxyConfig(info.Proxy, effectiveDialTimeout(info)) if proxyConfig != nil { return dialViaProxy(ctx, dialContext, *proxyConfig, targetAddr) } conn, err := dialContext(ctx, "tcp", targetAddr) return conn, nil, err } func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) { if info.Jump == nil { return nil, nil, errors.New("jump login info is nil") } jumpClient, err := loginWithContext(ctx, *info.Jump) if err != nil { return nil, nil, err } conn, err := jumpClient.dialTCPContext(ctx, "tcp", targetAddr, jumpClient.Close) if err != nil { _ = jumpClient.Close() return nil, nil, err } return conn, jumpClient, nil } func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) { if dialContext == nil { return nil, nil, errors.New("dial context is nil") } if strings.TrimSpace(proxy.Addr) == "" { return nil, nil, errors.New("proxy address is empty") } switch proxy.Type { case ProxyTypeSOCKS5: conn, err := dialSOCKS5(ctx, dialContext, proxy, targetAddr) return conn, nil, err case ProxyTypeHTTPConnect: conn, err := dialHTTPConnect(ctx, dialContext, proxy, targetAddr) return conn, nil, err default: return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type) } } func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) { conn, err := dialContext(ctx, "tcp", proxy.Addr) if err != nil { return nil, err } restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout) defer restoreDeadline() request := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr) if proxy.Username != "" || proxy.Password != "" { token := base64.StdEncoding.EncodeToString([]byte(proxy.Username + ":" + proxy.Password)) request += "Proxy-Authorization: Basic " + token + "\r\n" } request += "\r\n" if _, err := io.WriteString(conn, request); err != nil { _ = conn.Close() return nil, err } reader := bufio.NewReader(conn) response, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect}) if err != nil { _ = conn.Close() return nil, err } defer response.Body.Close() if response.StatusCode < 200 || response.StatusCode >= 300 { _, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024)) _ = conn.Close() return nil, fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status) } if reader.Buffered() == 0 { return conn, nil } return &bufferedConn{Conn: conn, reader: reader}, nil } func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) { conn, err := dialContext(ctx, "tcp", proxy.Addr) if err != nil { return nil, err } restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout) defer restoreDeadline() methods := []byte{0x00} useAuth := proxy.Username != "" || proxy.Password != "" if useAuth { methods = append(methods, 0x02) } hello := append([]byte{0x05, byte(len(methods))}, methods...) if _, err := conn.Write(hello); err != nil { _ = conn.Close() return nil, err } response := make([]byte, 2) if _, err := io.ReadFull(conn, response); err != nil { _ = conn.Close() return nil, err } if response[0] != 0x05 { _ = conn.Close() return nil, fmt.Errorf("invalid socks5 version %d", response[0]) } if response[1] == 0xFF { _ = conn.Close() return nil, errors.New("socks5 proxy has no acceptable auth method") } if response[1] == 0x02 { if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil { _ = conn.Close() return nil, err } } if err := writeSOCKS5Connect(conn, targetAddr); err != nil { _ = conn.Close() return nil, err } if err := readSOCKS5ConnectResponse(conn); err != nil { _ = conn.Close() return nil, err } return conn, nil } func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error { if len(username) > 255 || len(password) > 255 { return errors.New("socks5 username/password too long") } request := make([]byte, 0, 3+len(username)+len(password)) request = append(request, 0x01, byte(len(username))) request = append(request, []byte(username)...) request = append(request, byte(len(password))) request = append(request, []byte(password)...) if _, err := conn.Write(request); err != nil { return err } response := make([]byte, 2) if _, err := io.ReadFull(conn, response); err != nil { return err } if response[1] != 0x00 { return errors.New("socks5 username/password authentication failed") } return nil } func writeSOCKS5Connect(conn net.Conn, targetAddr string) error { host, portString, err := net.SplitHostPort(targetAddr) if err != nil { return err } port, err := strconv.Atoi(portString) if err != nil { return err } if port < 0 || port > 65535 { return fmt.Errorf("invalid port %d", port) } request := []byte{0x05, 0x01, 0x00} if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { request = append(request, 0x01) request = append(request, ip4...) } else { request = append(request, 0x04) request = append(request, ip.To16()...) } } else { if len(host) > 255 { return errors.New("socks5 target host too long") } request = append(request, 0x03, byte(len(host))) request = append(request, []byte(host)...) } request = append(request, byte(port>>8), byte(port)) _, err = conn.Write(request) return err } func readSOCKS5ConnectResponse(conn net.Conn) error { header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return err } if header[0] != 0x05 { return fmt.Errorf("invalid socks5 response version %d", header[0]) } if header[1] != 0x00 { return fmt.Errorf("socks5 connect failed with code %d", header[1]) } switch header[3] { case 0x01: if _, err := io.ReadFull(conn, make([]byte, 4)); err != nil { return err } case 0x03: size := make([]byte, 1) if _, err := io.ReadFull(conn, size); err != nil { return err } if _, err := io.ReadFull(conn, make([]byte, int(size[0]))); err != nil { return err } case 0x04: if _, err := io.ReadFull(conn, make([]byte, 16)); err != nil { return err } default: return fmt.Errorf("unsupported socks5 bind address type %d", header[3]) } _, err := io.ReadFull(conn, make([]byte, 2)) return err } func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() { if conn == nil { return func() {} } var ( deadline time.Time hasValue bool ) if ctx != nil { if ctxDeadline, ok := ctx.Deadline(); ok { deadline = ctxDeadline hasValue = true } } if timeout > 0 { timeoutDeadline := time.Now().Add(timeout) if !hasValue || timeoutDeadline.Before(deadline) { deadline = timeoutDeadline hasValue = true } } if !hasValue { return func() {} } _ = conn.SetDeadline(deadline) return func() { _ = conn.SetDeadline(time.Time{}) } } func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig { if proxy == nil { return nil } normalized := *proxy normalized.Type = ProxyType(strings.ToLower(strings.TrimSpace(string(normalized.Type)))) normalized.Addr = strings.TrimSpace(normalized.Addr) if normalized.Timeout <= 0 { normalized.Timeout = defaultTimeout } return &normalized } func joinHostPort(host string, port int) string { return net.JoinHostPort(strings.TrimSpace(host), strconv.Itoa(port)) }