package starssh import ( "context" "errors" "sync" "time" ) const defaultExecPoolMaxOpenConns = 4 var ErrExecPoolClosed = errors.New("exec pool is closed") type ExecPoolConfig struct { Login LoginInput MaxOpenConns int MaxIdleConns int MaxIdleTime time.Duration DisableHealthCheck bool HealthCheckTimeout time.Duration } type ExecPoolStats struct { MaxOpenConns int MaxIdleConns int MaxIdleTime time.Duration OpenConns int IdleConns int InUseConns int } type ExecPool struct { loginInfo LoginInput maxOpen int maxIdle int maxIdleTime time.Duration idle chan *pooledClient done chan struct{} closeOnce sync.Once healthCheckOnAcquire bool healthCheckTimeout time.Duration mu sync.Mutex open int closed bool } type pooledClient struct { client *StarSSH idleAt time.Time } func NewExecPool(config ExecPoolConfig) *ExecPool { maxOpen := config.MaxOpenConns if maxOpen <= 0 { maxOpen = defaultExecPoolMaxOpenConns } maxIdle := config.MaxIdleConns if maxIdle <= 0 || maxIdle > maxOpen { maxIdle = maxOpen } return &ExecPool{ loginInfo: config.Login, maxOpen: maxOpen, maxIdle: maxIdle, maxIdleTime: normalizeMaxIdleTime(config.MaxIdleTime), idle: make(chan *pooledClient, maxIdle), done: make(chan struct{}), healthCheckOnAcquire: !config.DisableHealthCheck, healthCheckTimeout: normalizeHealthCheckTimeout(config.HealthCheckTimeout), } } func (p *ExecPool) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { client, err := p.Acquire(ctx) if err != nil { return nil, err } result, execErr := client.Exec(ctx, req) if execErr != nil { p.Discard(client) return result, execErr } if releaseErr := p.Release(client); releaseErr != nil { return result, releaseErr } return result, nil } func (p *ExecPool) ExecString(ctx context.Context, command string) (*ExecResult, error) { return p.Exec(ctx, ExecRequest{ Command: command, }) } func (p *ExecPool) ExecStream(ctx context.Context, req ExecRequest, onChunk func(ExecStreamChunk)) (*ExecResult, error) { client, err := p.Acquire(ctx) if err != nil { return nil, err } result, execErr := client.ExecStream(ctx, req, onChunk) if execErr != nil { p.Discard(client) return result, execErr } if releaseErr := p.Release(client); releaseErr != nil { return result, releaseErr } return result, nil } func (p *ExecPool) WarmUp(ctx context.Context, targetIdle int) error { if p == nil { return errors.New("exec pool is nil") } if ctx == nil { ctx = context.Background() } targetIdle = p.normalizeWarmUpTarget(targetIdle) if targetIdle == 0 { return nil } for { if err := ctx.Err(); err != nil { return err } idleCount, create, err := p.tryWarmUp(targetIdle) if err != nil { return err } if idleCount >= targetIdle || !create { return nil } conn, err := LoginContext(ctx, p.loginInfo) if err != nil { p.releaseSlot() return err } if err := p.Release(conn); err != nil { return err } } } func (p *ExecPool) Acquire(ctx context.Context) (*StarSSH, error) { if p == nil { return nil, errors.New("exec pool is nil") } if ctx == nil { ctx = context.Background() } for { idleClient, create, err := p.tryAcquire() if err != nil { return nil, err } if idleClient != nil { client, ok := p.takeIdleClient(ctx, idleClient) if !ok { continue } return client, nil } if create { conn, err := LoginContext(ctx, p.loginInfo) if err != nil { p.releaseSlot() return nil, err } return conn, nil } select { case <-ctx.Done(): return nil, ctx.Err() case <-p.done: return nil, ErrExecPoolClosed case idleClient = <-p.idle: if idleClient == nil { continue } client, ok := p.takeIdleClient(ctx, idleClient) if !ok { continue } return client, nil } } } func (p *ExecPool) Release(client *StarSSH) error { if p == nil { return errors.New("exec pool is nil") } if client == nil { p.releaseSlot() return nil } p.mu.Lock() if p.closed { p.mu.Unlock() p.closeClient(client) return nil } select { case p.idle <- &pooledClient{ client: client, idleAt: time.Now(), }: p.mu.Unlock() return nil default: p.mu.Unlock() p.closeClient(client) return nil } } func (p *ExecPool) Discard(client *StarSSH) { if p == nil { return } if client == nil { p.releaseSlot() return } p.closeClient(client) } func (p *ExecPool) Stats() ExecPoolStats { if p == nil { return ExecPoolStats{} } p.mu.Lock() defer p.mu.Unlock() idleCount := len(p.idle) openCount := p.open inUseCount := openCount - idleCount if inUseCount < 0 { inUseCount = 0 } return ExecPoolStats{ MaxOpenConns: p.maxOpen, MaxIdleConns: p.maxIdle, MaxIdleTime: p.maxIdleTime, OpenConns: openCount, IdleConns: idleCount, InUseConns: inUseCount, } } func (p *ExecPool) Close() error { if p == nil { return nil } var closeErr error p.closeOnce.Do(func() { p.mu.Lock() p.closed = true idleClients := p.drainIdleLocked() p.mu.Unlock() close(p.done) for _, client := range idleClients { if err := client.Close(); err != nil && closeErr == nil { closeErr = err } } }) return closeErr } func (p *ExecPool) CloseIdle() error { if p == nil { return nil } p.mu.Lock() idleClients := p.drainIdleLocked() p.mu.Unlock() var closeErr error for _, client := range idleClients { if err := client.Close(); err != nil && closeErr == nil { closeErr = err } } return closeErr } func (p *ExecPool) tryAcquire() (*pooledClient, bool, error) { p.mu.Lock() defer p.mu.Unlock() if p.closed { return nil, false, ErrExecPoolClosed } select { case client := <-p.idle: return client, false, nil default: } if p.open < p.maxOpen { p.open++ return nil, true, nil } return nil, false, nil } func (p *ExecPool) tryWarmUp(targetIdle int) (int, bool, error) { p.mu.Lock() defer p.mu.Unlock() if p.closed { return 0, false, ErrExecPoolClosed } idleCount := len(p.idle) if idleCount >= targetIdle { return idleCount, false, nil } if p.open >= p.maxOpen { return idleCount, false, nil } p.open++ return idleCount, true, nil } func (p *ExecPool) normalizeWarmUpTarget(targetIdle int) int { if p == nil || p.maxIdle <= 0 { return 0 } if targetIdle <= 0 { return p.maxIdle } if targetIdle > p.maxIdle { return p.maxIdle } return targetIdle } func (p *ExecPool) takeIdleClient(ctx context.Context, idleClient *pooledClient) (*StarSSH, bool) { if idleClient == nil { return nil, false } if idleClient.client == nil { p.releaseSlot() return nil, false } if p.isIdleExpired(idleClient) { p.closePooledClient(idleClient) return nil, false } if err := p.healthCheckClient(ctx, idleClient.client); err != nil { p.closePooledClient(idleClient) return nil, false } return idleClient.client, true } func (p *ExecPool) isIdleExpired(client *pooledClient) bool { if p == nil || client == nil || client.client == nil { return false } if p.maxIdleTime <= 0 || client.idleAt.IsZero() { return false } return time.Since(client.idleAt) >= p.maxIdleTime } func (p *ExecPool) drainIdleLocked() []*StarSSH { clients := make([]*StarSSH, 0, len(p.idle)) for { select { case idleClient := <-p.idle: if p.open > 0 { p.open-- } if idleClient == nil || idleClient.client == nil { continue } clients = append(clients, idleClient.client) default: return clients } } } func (p *ExecPool) releaseSlot() { p.mu.Lock() defer p.mu.Unlock() if p.open > 0 { p.open-- } } func (p *ExecPool) closeClient(client *StarSSH) { if client != nil { _ = client.Close() } p.releaseSlot() } func (p *ExecPool) closePooledClient(client *pooledClient) { if client == nil { return } if client.client == nil { p.releaseSlot() return } p.closeClient(client.client) } func (p *ExecPool) healthCheckClient(ctx context.Context, client *StarSSH) error { if client == nil { return errors.New("ssh client is nil") } if !p.healthCheckOnAcquire { return nil } if ctx == nil { ctx = context.Background() } timeout := p.healthCheckTimeout if timeout > 0 { healthCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() ctx = healthCtx } return client.PingContext(ctx) } func normalizeHealthCheckTimeout(timeout time.Duration) time.Duration { if timeout <= 0 { return defaultKeepAliveTimeout } return timeout } func normalizeMaxIdleTime(timeout time.Duration) time.Duration { if timeout <= 0 { return 0 } return timeout }