package main import ( "fmt" "io/ioutil" "os" "path/filepath" "regexp" "runtime" "strings" "b612.me/starainrt" "b612.me/sshd" "github.com/spf13/cobra" ) var sftpcmd = &cobra.Command{ Use: "sftp", Short: "sftp上传下载", Long: "sftp上传下载", Run: func(this *cobra.Command, args []string) { d, _ := this.Flags().GetBool("download") s, _ := this.Flags().GetString("src") r, _ := this.Flags().GetString("dst") i, _ := this.Flags().GetString("identify") k, _ := this.Flags().GetString("password") p, _ := this.Flags().GetInt("port") b, _ := this.Flags().GetInt("buffer") g, _ := this.Flags().GetString("regexp") var user, host string var err error if len(args) != 1 { fmt.Println("sftp <[user@]Host> -s -d ") this.Help() return } hosts := strings.Split(args[0], "@") if len(hosts) == 1 { host = hosts[0] user = "root" } else { user = hosts[0] host = hosts[1] } fmt.Println("进行SSH连接……") myssh := new(sshd.StarSSH) err = myssh.Connect(user, k, host, i, p) if err != nil { fmt.Println(err) return } defer myssh.Close() fmt.Println("已连接上……") sftp, err := sshd.CreateSftp(myssh.Client) if err != nil { fmt.Println(err) return } defer sftp.Close() fmt.Println("已建立SFTP……") shell := func(pect float64) { if pect != 100.0 { fmt.Printf("传输已完成:%f%%\r", pect) } else { fmt.Printf("传输已完成:%f%%\n", pect) } } var UploadDir func(string, string) UploadDir = func(fs, remote string) { if runtime.GOOS == "windows" { fs = strings.Replace(fs, "/", "\\", -1) } sftp.MkdirAll(remote) abspath, _ := filepath.Abs(fs) dir, err := ioutil.ReadDir(fs) if err != nil { fmt.Println(err) return } for _, v := range dir { if v.IsDir() { if g != "" { continue } UploadDir(abspath+string(os.PathSeparator)+v.Name(), remote+"/"+v.Name()) } else { if ok, _ := regexp.MatchString(g, v.Name()); !ok { continue } fmt.Println("上传:" + abspath + string(os.PathSeparator) + v.Name()) err = sshd.FtpTransferOutFunc(abspath+string(os.PathSeparator)+v.Name(), remote+"/"+v.Name(), b, shell, sftp) if err != nil { fmt.Println(err) continue } } } } var DownloadDir func(string, string) DownloadDir = func(fs, remote string) { abspath, _ := filepath.Abs(remote) os.MkdirAll(abspath, 0755) dir, err := sftp.ReadDir(fs) if err != nil { fmt.Println("读取错误", err) return } for _, v := range dir { if v.IsDir() { if g != "" { continue } DownloadDir(fs+"/"+v.Name(), abspath+string(os.PathSeparator)+v.Name()) } else { if ok, _ := regexp.MatchString(g, v.Name()); !ok { continue } fmt.Println("下载:" + fs + "/" + v.Name()) err = sshd.FtpTransferInFunc(fs+"/"+v.Name(), abspath+string(os.PathSeparator)+v.Name(), b, shell, sftp) if err != nil { fmt.Println(err) continue } } } } if !d { if !starainrt.Exists(s) { fmt.Println("本地文件或文件夹:" + s + "不存在") return } if starainrt.IsFile(s) { err = sshd.FtpTransferOutFunc(s, r, b, shell, sftp) } else { UploadDir(s, r) } } else { if !myssh.Exists(s) { fmt.Println("远端文件或文件夹:" + s + "不存在") return } stat, err := sftp.Stat(s) if err != nil { fmt.Println("错误:", err) return } if !stat.IsDir() { err = sshd.FtpTransferInFunc(s, r, b, shell, sftp) } else { DownloadDir(s, r) } } if err != nil { fmt.Println(err) } }, } func init() { sftpcmd.Flags().BoolP("download", "D", false, "进行下载") sftpcmd.Flags().StringP("identify", "i", "", "RSA登录密钥") sftpcmd.Flags().StringP("password", "k", "", "登录密码") sftpcmd.Flags().StringP("src", "s", "", "本机路径/若为下载则相反") sftpcmd.Flags().StringP("dst", "d", "", "远程路径/若为下载则相反") sftpcmd.Flags().IntP("port", "p", 22, "登录端口") sftpcmd.Flags().StringP("regexp", "r", "", "正则表达式") sftpcmd.Flags().IntP("buffer", "b", 10240, "buffer大小") }