starssh/transport.go

337 lines
8.2 KiB
Go
Raw Permalink Normal View History

refactor: 重构 starssh 核心运行时并补强 ssh/exec/terminal/sftp 能力 - 拆分原有单体 ssh.go,按职责重组为 types、utils、transport、login、keepalive、session、exec、pool、shell、terminal、forward、hostkey、state 等模块,并补充平台相关实现 - 重做登录与连接运行时,补齐基于 context 的建连、jump/proxy 链路、可配置认证顺序,以及 Unix/Windows 下的 ssh-agent 支持 - 新增正式非交互执行模型 ExecRequest/ExecResult,支持流式输出、溢出统计、超时控制,以及 posix/powershell/cmd/raw 多方言执行 - 保留旧 shell 风格兼容接口,同时让路径/用户探测等 helper 具备跨 shell fallback,避免 Windows 目标继续硬依赖 POSIX 命令 - 新增 TerminalSession 作为原始交互终端基座,提供 IO attach、resize、signal/control、退出状态与关闭原因管理 - 重构端口转发语义,默认复用当前 SSH 连接,并显式提供 detached 的本地/动态转发模式承载隔离场景 - 梳理 keepalive 与取消语义,区分仅取消本次操作和关闭整条连接,并统一连接状态与传输关闭路径 - 围绕新的 session/连接生命周期重做执行池与运行时支撑 - 大幅增强 SFTP 传输链路,补齐更安全的原子替换、校验、进度回调、重试隔离、可复用 client 生命周期与失败语义 - 新增取消语义、keepalive、SFTP、forward、terminal input 等关键回归测试,提升核心链路稳定性
2026-04-26 10:45:39 +08:00
package starssh
import (
	"bufio"
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)
// bufferedConn is a net.Conn whose reads are served from a bufio.Reader
// instead of directly from the embedded connection. dialHTTPConnect returns it
// when the reader used to parse the proxy response still holds buffered bytes,
// so that those bytes are not lost to the caller.
type bufferedConn struct {
	net.Conn
	// reader is the buffered reader all Read calls are delegated to.
	reader *bufio.Reader
}

// Read fills p from the buffered reader. A nil receiver or a nil reader is
// reported as an exhausted stream: (0, io.EOF).
func (c *bufferedConn) Read(p []byte) (int, error) {
	if c != nil && c.reader != nil {
		return c.reader.Read(p)
	}
	return 0, io.EOF
}
// resolveDialContext picks the dial function used to open the outbound TCP
// connection: the caller-supplied info.DialContext when one is set, otherwise
// the DialContext method of a net.Dialer whose Timeout is info.Timeout.
func resolveDialContext(info LoginInput) DialContextFunc {
	if custom := info.DialContext; custom != nil {
		return custom
	}
	fallback := &net.Dialer{Timeout: info.Timeout}
	return fallback.DialContext
}
// dialTargetConn opens the transport connection to info.Addr:info.Port.
//
// Route selection, in priority order:
//  1. info.Jump is set: dial through the jump host (dialViaJump).
//  2. info.Proxy normalizes to a non-nil config: dial through that proxy.
//  3. otherwise: dial the target directly with the resolved dial function.
//
// The second return value is whatever the chosen route reports as its
// intermediate client; the proxy and direct routes always return nil there.
func dialTargetConn(ctx context.Context, info LoginInput) (net.Conn, *StarSSH, error) {
	target := joinHostPort(info.Addr, info.Port)
	if info.Jump != nil {
		return dialViaJump(ctx, info, target)
	}

	dial := resolveDialContext(info)
	if cfg := normalizeProxyConfig(info.Proxy, info.Timeout); cfg != nil {
		return dialViaProxy(ctx, dial, *cfg, target)
	}

	direct, err := dial(ctx, "tcp", target)
	return direct, nil, err
}
// dialViaJump logs in to the jump host described by info.Jump and opens a TCP
// connection to targetAddr through it.
//
// On success it returns the tunnelled connection together with the jump
// client. jumpClient.Close is also handed to dialTCPContext as its final
// argument. If the tunnelled dial fails, the jump client is closed before the
// error is returned.
func dialViaJump(ctx context.Context, info LoginInput, targetAddr string) (net.Conn, *StarSSH, error) {
	jumpInfo := info.Jump
	if jumpInfo == nil {
		return nil, nil, errors.New("jump login info is nil")
	}

	hop, err := loginWithContext(ctx, *jumpInfo)
	if err != nil {
		return nil, nil, err
	}

	tunnel, dialErr := hop.dialTCPContext(ctx, "tcp", targetAddr, hop.Close)
	if dialErr == nil {
		return tunnel, hop, nil
	}

	// Best-effort cleanup: the dial error is the one worth reporting.
	_ = hop.Close()
	return nil, nil, dialErr
}
// dialViaProxy connects to targetAddr through the given proxy, dispatching on
// proxy.Type to the SOCKS5 or HTTP CONNECT implementation.
//
// It fails fast when dialContext is nil, when proxy.Addr is blank, or when the
// proxy type is not one of the supported kinds. The *StarSSH result is always
// nil on this route.
func dialViaProxy(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, *StarSSH, error) {
	switch {
	case dialContext == nil:
		return nil, nil, errors.New("dial context is nil")
	case strings.TrimSpace(proxy.Addr) == "":
		return nil, nil, errors.New("proxy address is empty")
	}

	var connect func(context.Context, DialContextFunc, ProxyConfig, string) (net.Conn, error)
	switch proxy.Type {
	case ProxyTypeSOCKS5:
		connect = dialSOCKS5
	case ProxyTypeHTTPConnect:
		connect = dialHTTPConnect
	default:
		return nil, nil, fmt.Errorf("unsupported proxy type %q", proxy.Type)
	}

	conn, err := connect(ctx, dialContext, proxy, targetAddr)
	return conn, nil, err
}
// dialHTTPConnect establishes a tunnel to targetAddr through an HTTP proxy
// using the CONNECT method.
//
// It dials proxy.Addr, sends a CONNECT request (adding a Basic
// Proxy-Authorization header when a username or password is configured) and
// parses the proxy's response. Any 2xx status is treated as success; any other
// status closes the connection and is reported as an error. The handshake runs
// under the deadline installed by applyConnDeadline, which is undone on return.
func dialHTTPConnect(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
	conn, err := dialContext(ctx, "tcp", proxy.Addr)
	if err != nil {
		return nil, err
	}
	restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
	defer restoreDeadline()

	// fail discards the proxy connection and reports err to the caller.
	fail := func(err error) (net.Conn, error) {
		_ = conn.Close()
		return nil, err
	}

	payload := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", targetAddr, targetAddr)
	if hasCredentials := proxy.Username != "" || proxy.Password != ""; hasCredentials {
		credentials := []byte(proxy.Username + ":" + proxy.Password)
		payload += "Proxy-Authorization: Basic " + base64.StdEncoding.EncodeToString(credentials) + "\r\n"
	}
	payload += "\r\n"
	if _, err := io.WriteString(conn, payload); err != nil {
		return fail(err)
	}

	br := bufio.NewReader(conn)
	response, err := http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
	if err != nil {
		return fail(err)
	}
	defer response.Body.Close()

	if status := response.StatusCode; status >= 200 && status < 300 {
		// The proxy may have sent tunnel bytes right behind its response;
		// keep serving reads from br so those bytes are not dropped.
		if br.Buffered() > 0 {
			return &bufferedConn{Conn: conn, reader: br}, nil
		}
		return conn, nil
	}

	// Drain a bounded slice of the rejection body before giving up.
	_, _ = io.Copy(io.Discard, io.LimitReader(response.Body, 1024))
	return fail(fmt.Errorf("http CONNECT proxy rejected target %s: %s", targetAddr, response.Status))
}
// dialSOCKS5 establishes a tunnel to targetAddr through a SOCKS5 proxy.
//
// It dials proxy.Addr, offers "no authentication" (0x00) and, when a username
// or password is configured, "username/password" (0x02), performs whichever
// of those the proxy selects, then issues a CONNECT for targetAddr and waits
// for the proxy's reply. The handshake runs under the deadline installed by
// applyConnDeadline, which is undone on return.
//
// The proxy's method selection is validated strictly: a method that was not
// offered (including 0x02 when no credentials are configured) is rejected
// instead of being treated as "no authentication", because continuing would
// send a CONNECT the proxy is not expecting at that point of the exchange.
//
// On any failure after the TCP dial, the proxy connection is closed and a nil
// connection is returned with the error.
func dialSOCKS5(ctx context.Context, dialContext DialContextFunc, proxy ProxyConfig, targetAddr string) (net.Conn, error) {
	conn, err := dialContext(ctx, "tcp", proxy.Addr)
	if err != nil {
		return nil, err
	}
	restoreDeadline := applyConnDeadline(conn, ctx, proxy.Timeout)
	defer restoreDeadline()

	// fail discards the half-negotiated proxy connection and reports err.
	fail := func(err error) (net.Conn, error) {
		_ = conn.Close()
		return nil, err
	}

	// Greeting: VER, NMETHODS, METHODS...
	methods := []byte{0x00}
	useAuth := proxy.Username != "" || proxy.Password != ""
	if useAuth {
		methods = append(methods, 0x02)
	}
	hello := append([]byte{0x05, byte(len(methods))}, methods...)
	if _, err := conn.Write(hello); err != nil {
		return fail(err)
	}

	// Method selection reply: VER, METHOD.
	response := make([]byte, 2)
	if _, err := io.ReadFull(conn, response); err != nil {
		return fail(err)
	}
	if response[0] != 0x05 {
		return fail(fmt.Errorf("invalid socks5 version %d", response[0]))
	}

	switch response[1] {
	case 0x00:
		// No authentication required.
	case 0x02:
		if !useAuth {
			return fail(errors.New("socks5 proxy selected username/password auth that was not offered"))
		}
		if err := writeSOCKS5UserPassAuth(conn, proxy.Username, proxy.Password); err != nil {
			return fail(err)
		}
	case 0xFF:
		return fail(errors.New("socks5 proxy has no acceptable auth method"))
	default:
		return fail(fmt.Errorf("socks5 proxy selected unsupported auth method %d", response[1]))
	}

	if err := writeSOCKS5Connect(conn, targetAddr); err != nil {
		return fail(err)
	}
	if err := readSOCKS5ConnectResponse(conn); err != nil {
		return fail(err)
	}
	return conn, nil
}
// writeSOCKS5UserPassAuth performs the SOCKS5 username/password
// sub-negotiation on conn.
//
// It sends one message of the form
//
//	0x01 | len(username) | username | len(password) | password
//
// in a single Write, then reads the two-byte reply and succeeds only when the
// status byte (the second one) is 0x00. The first reply byte is not inspected.
// Credentials longer than 255 bytes are rejected before any I/O happens, since
// each length must fit in one byte.
func writeSOCKS5UserPassAuth(conn net.Conn, username string, password string) error {
	const maxFieldLen = 255
	if len(username) > maxFieldLen || len(password) > maxFieldLen {
		return errors.New("socks5 username/password too long")
	}

	message := append([]byte{0x01, byte(len(username))}, username...)
	message = append(message, byte(len(password)))
	message = append(message, password...)
	if _, err := conn.Write(message); err != nil {
		return err
	}

	var reply [2]byte
	if _, err := io.ReadFull(conn, reply[:]); err != nil {
		return err
	}
	if status := reply[1]; status != 0x00 {
		return errors.New("socks5 username/password authentication failed")
	}
	return nil
}
// writeSOCKS5Connect sends a SOCKS5 CONNECT request for targetAddr on conn.
//
// targetAddr must be in host:port form with a numeric port in 0..65535. The
// host is encoded as an IPv4 address (type 0x01), an IPv6 address (type 0x04)
// or, when it does not parse as an IP, as a length-prefixed domain name
// (type 0x03, at most 255 bytes). The whole request is sent in one Write;
// validation failures are returned before anything is written.
func writeSOCKS5Connect(conn net.Conn, targetAddr string) error {
	host, rawPort, err := net.SplitHostPort(targetAddr)
	if err != nil {
		return err
	}
	port, err := strconv.Atoi(rawPort)
	if err != nil {
		return err
	}
	if port < 0 || port > 65535 {
		return fmt.Errorf("invalid port %d", port)
	}

	// Encode ATYP followed by the address in its type-specific form.
	var address []byte
	switch ip := net.ParseIP(host); {
	case ip == nil:
		if len(host) > 255 {
			return errors.New("socks5 target host too long")
		}
		address = append([]byte{0x03, byte(len(host))}, host...)
	case ip.To4() != nil:
		address = append([]byte{0x01}, ip.To4()...)
	default:
		address = append([]byte{0x04}, ip.To16()...)
	}

	// VER=5, CMD=CONNECT, RSV=0, then address and big-endian port.
	packet := make([]byte, 0, 3+len(address)+2)
	packet = append(packet, 0x05, 0x01, 0x00)
	packet = append(packet, address...)
	packet = append(packet, byte(port>>8), byte(port))

	_, err = conn.Write(packet)
	return err
}
// readSOCKS5ConnectResponse consumes the proxy's reply to a CONNECT request.
//
// It reads the four-byte header, requires version 0x05 and reply code 0x00,
// then reads and discards the bound address (IPv4, length-prefixed domain, or
// IPv6, as indicated by the address-type byte) and the two-byte bound port so
// that conn is positioned at the first byte of tunnelled data. Unknown address
// types are reported as errors.
func readSOCKS5ConnectResponse(conn net.Conn) error {
	var head [4]byte
	if _, err := io.ReadFull(conn, head[:]); err != nil {
		return err
	}
	if version := head[0]; version != 0x05 {
		return fmt.Errorf("invalid socks5 response version %d", version)
	}
	if code := head[1]; code != 0x00 {
		return fmt.Errorf("socks5 connect failed with code %d", code)
	}

	// discard reads exactly n bytes from conn and throws them away.
	discard := func(n int) error {
		_, err := io.ReadFull(conn, make([]byte, n))
		return err
	}

	var addrLen int
	switch atyp := head[3]; atyp {
	case 0x01:
		addrLen = 4
	case 0x04:
		addrLen = 16
	case 0x03:
		// Domain names carry their own one-byte length prefix.
		var prefix [1]byte
		if _, err := io.ReadFull(conn, prefix[:]); err != nil {
			return err
		}
		addrLen = int(prefix[0])
	default:
		return fmt.Errorf("unsupported socks5 bind address type %d", atyp)
	}

	if err := discard(addrLen); err != nil {
		return err
	}
	return discard(2)
}
// applyConnDeadline bounds blocking I/O on conn for the duration of a
// handshake and returns a function that lifts the bound again.
//
// The deadline installed on conn is the earlier of ctx's deadline (if any) and
// now+timeout (if timeout > 0). In addition, when ctx can be cancelled, a
// watcher goroutine forces an already-expired deadline onto conn as soon as
// ctx is done, so that a cancellation without a deadline still unblocks reads
// and writes that are stuck on an unresponsive peer.
//
// The returned function stops the watcher, waits for it to exit, and clears
// the deadline if one was installed or forced. It is safe to call more than
// once and must be called (typically via defer) so the watcher does not
// outlive the handshake. When conn is nil, or when there is neither a
// deadline nor a cancellable ctx, nothing is installed and the returned
// function is a no-op.
func applyConnDeadline(conn net.Conn, ctx context.Context, timeout time.Duration) func() {
	if conn == nil {
		return func() {}
	}

	var (
		deadline    time.Time
		hasDeadline bool
		cancelled   <-chan struct{}
	)
	if ctx != nil {
		if ctxDeadline, ok := ctx.Deadline(); ok {
			deadline, hasDeadline = ctxDeadline, true
		}
		// Done returns nil for contexts that can never be cancelled.
		cancelled = ctx.Done()
	}
	if timeout > 0 {
		if timeoutDeadline := time.Now().Add(timeout); !hasDeadline || timeoutDeadline.Before(deadline) {
			deadline, hasDeadline = timeoutDeadline, true
		}
	}

	if !hasDeadline && cancelled == nil {
		return func() {}
	}
	if hasDeadline {
		_ = conn.SetDeadline(deadline)
	}
	if cancelled == nil {
		return func() {
			_ = conn.SetDeadline(time.Time{})
		}
	}

	var (
		stop    = make(chan struct{})
		exited  = make(chan struct{})
		aborted bool // written by the watcher, read only after exited is closed
		once    sync.Once
	)
	go func() {
		defer close(exited)
		select {
		case <-cancelled:
			aborted = true
			// A deadline in the past makes pending and future I/O fail at once.
			_ = conn.SetDeadline(time.Unix(1, 0))
		case <-stop:
		}
	}()

	return func() {
		once.Do(func() {
			close(stop)
			// Wait for the watcher so it cannot re-arm the deadline after
			// it has been cleared below.
			<-exited
			if hasDeadline || aborted {
				_ = conn.SetDeadline(time.Time{})
			}
		})
	}
}
// normalizeProxyConfig returns a cleaned-up copy of proxy, or nil when proxy
// is nil. The input is never modified.
//
// In the copy, Type is trimmed and lower-cased, Addr is trimmed, and a
// non-positive Timeout is replaced by defaultTimeout.
func normalizeProxyConfig(proxy *ProxyConfig, defaultTimeout time.Duration) *ProxyConfig {
	if proxy == nil {
		return nil
	}

	cleaned := new(ProxyConfig)
	*cleaned = *proxy

	kind := strings.TrimSpace(string(cleaned.Type))
	cleaned.Type = ProxyType(strings.ToLower(kind))
	cleaned.Addr = strings.TrimSpace(cleaned.Addr)
	if cleaned.Timeout <= 0 {
		cleaned.Timeout = defaultTimeout
	}
	return cleaned
}
// joinHostPort builds a "host:port" network address from a host string and a
// numeric port. Surrounding whitespace on host is removed first, and
// net.JoinHostPort adds brackets around hosts that contain a colon (such as
// IPv6 literals).
func joinHostPort(host string, port int) string {
	trimmedHost := strings.TrimSpace(host)
	portText := strconv.Itoa(port)
	return net.JoinHostPort(trimmedHost, portText)
}