This commit is contained in:
兔子 2025-06-17 13:10:35 +08:00
parent f77fc6dddf
commit 6b6b5a6f0f
11 changed files with 415 additions and 133 deletions

View File

@ -1,6 +1,7 @@
package cert package cert
import ( import (
"b612.me/apps/b612/utils"
"b612.me/stario" "b612.me/stario"
"b612.me/starlog" "b612.me/starlog"
"crypto" "crypto"
@ -205,7 +206,83 @@ var CmdFastGen = &cobra.Command{
Short: "快速生成证书", Short: "快速生成证书",
Long: "快速生成证书", Long: "快速生成证书",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
return if promptMode {
if fastgen.Country == "" {
fastgen.Country = stario.MessageBox("请输入国家:", "").MustString()
}
if fastgen.Province == "" {
fastgen.Province = stario.MessageBox("请输入省份:", "").MustString()
}
if fastgen.City == "" {
fastgen.City = stario.MessageBox("请输入城市:", "").MustString()
}
if fastgen.Organization == "" {
fastgen.Organization = stario.MessageBox("请输入组织:", "").MustString()
}
if fastgen.OrganizationUnit == "" {
fastgen.OrganizationUnit = stario.MessageBox("请输入组织单位:", "").MustString()
}
if fastgen.CommonName == "" {
fastgen.CommonName = stario.MessageBox("请输入通用名称:", "").MustString()
}
if fastgen.Dns == nil {
fastgen.Dns = stario.MessageBox("请输入dns名称用逗号分割", "").MustSliceString(",")
}
if fastgen.Type == "" {
fastgen.Type = stario.MessageBox("请输入证书类型(RSA/ECDSA)", "RSA").MustString()
}
if fastgen.Bits <= 0 {
fastgen.Bits = stario.MessageBox("请输入证书位数:", "2048").MustInt()
}
if startStr == "" {
startStr = stario.MessageBox("请输入证书开始时间,格式:2006-01-02T15:04:05Z07:00", time.Now().Format(time.RFC3339)).MustString()
}
if endStr == "" {
endStr = stario.MessageBox("请输入证书结束时间,格式:2006-01-02T15:04:05Z07:00", time.Now().AddDate(1, 0, 0).Format(time.RFC3339)).MustString()
}
}
var err error
fastgen.StartDate, err = time.Parse(time.RFC3339, startStr)
if err != nil {
starlog.Errorln("开始时间格式错误,格式:2006-01-02T15:04:05Z07:00", err)
os.Exit(1)
}
fastgen.EndDate, err = time.Parse(time.RFC3339, endStr)
if err != nil {
starlog.Errorln("结束时间格式错误,格式:2006-01-02T15:04:05Z07:00", err)
os.Exit(1)
}
if caCert != "" && caKey != "" {
fastgen.CAPriv, fastgen.CA, err = LoadCA(caKey, caCert, caKeyPwd)
if err != nil {
starlog.Errorln("加载CA错误", err)
os.Exit(1)
}
}
if fastgen.CAPriv == nil {
fastgen.CA, fastgen.CAPriv = utils.ToolCert("")
}
byteCrt, byteKey, err := utils.GenerateCert(fastgen)
if err != nil {
starlog.Errorln("生成证书错误", err)
os.Exit(1)
}
name := fastgen.CommonName
if name == "" {
name = "cert"
}
err = os.WriteFile(filepath.Join(savefolder, name+".crt"), byteCrt, 0644)
if err != nil {
starlog.Errorln("保存证书错误", err)
os.Exit(1)
}
starlog.Infoln("保存证书成功", filepath.Join(savefolder, name+".crt"))
err = os.WriteFile(filepath.Join(savefolder, name+".key"), byteKey, 0644)
if err != nil {
starlog.Errorln("保存私钥错误", err)
os.Exit(1)
}
starlog.Infoln("保存私钥成功", filepath.Join(savefolder, name+".key"))
}, },
} }
@ -230,15 +307,17 @@ var CmdParse = &cobra.Command{
}, },
} }
var fastgen utils.GenerateCertParams
func init() { func init() {
Cmd.AddCommand(CmdCsr) Cmd.AddCommand(CmdCsr)
CmdCsr.Flags().BoolVarP(&promptMode, "prompt", "P", false, "是否交互模式") CmdCsr.Flags().BoolVarP(&promptMode, "prompt", "P", false, "是否交互模式")
CmdCsr.Flags().StringVarP(&country, "country", "c", "CN", "国家") CmdCsr.Flags().StringVarP(&country, "country", "c", "", "国家")
CmdCsr.Flags().StringVarP(&province, "province", "p", "B612", "省份") CmdCsr.Flags().StringVarP(&province, "province", "p", "", "省份")
CmdCsr.Flags().StringVarP(&city, "city", "t", "B612", "城市") CmdCsr.Flags().StringVarP(&city, "city", "t", "", "城市")
CmdCsr.Flags().StringVarP(&org, "org", "o", "", "组织") CmdCsr.Flags().StringVarP(&org, "org", "o", "", "组织")
CmdCsr.Flags().StringVarP(&orgUnit, "orgUnit", "u", "", "组织单位") CmdCsr.Flags().StringVarP(&orgUnit, "orgUnit", "u", "", "组织单位")
CmdCsr.Flags().StringVarP(&name, "name", "n", "Starainrt", "通用名称") CmdCsr.Flags().StringVarP(&name, "name", "n", "", "通用名称")
CmdCsr.Flags().StringSliceVarP(&dnsName, "dnsName", "d", nil, "dns名称") CmdCsr.Flags().StringSliceVarP(&dnsName, "dnsName", "d", nil, "dns名称")
CmdCsr.Flags().StringVarP(&savefolder, "savefolder", "s", "./", "保存文件夹") CmdCsr.Flags().StringVarP(&savefolder, "savefolder", "s", "./", "保存文件夹")
CmdCsr.Flags().StringVarP(&caKey, "secret-key", "k", "", "加密私钥") CmdCsr.Flags().StringVarP(&caKey, "secret-key", "k", "", "加密私钥")
@ -249,7 +328,7 @@ func init() {
//CmdCsr.Flags().BoolVarP(&maxPathLenZero, "maxPathLenZero", "z", false, "允许最大路径长度为0") //CmdCsr.Flags().BoolVarP(&maxPathLenZero, "maxPathLenZero", "z", false, "允许最大路径长度为0")
//CmdCsr.Flags().IntVarP(&maxPathLen, "maxPathLen", "m", 0, "最大路径长度") //CmdCsr.Flags().IntVarP(&maxPathLen, "maxPathLen", "m", 0, "最大路径长度")
CmdGen.Flags().IntVarP(&keyUsage, "keyUsage", "u", 0, "证书使用类型默认数字00表示数字签名和密钥加密1表示证书签名2表示CRL签名4表示密钥协商8表示数据加密") CmdGen.Flags().IntVarP(&keyUsage, "keyUsage", "u", 0, "证书使用类型默认数字00表示数字签名和密钥加密1表示证书签名2表示CRL签名4表示密钥协商8表示数据加密")
CmdGen.Flags().IntSliceVarP(&extKeyUsage, "extKeyUsage", "e", nil, "扩展证书使用类型默认数字00表示服务器认证1表示客户端认证2表示代码签名3表示电子邮件保护4表示IPSEC终端系统5表示IPSEC隧道6表示IPSEC用户7表示时间戳8表示OCSP签名9表示Microsoft服务器网关加密10表示Netscape服务器网关加密11表示Microsoft商业代码签名12表示Microsoft内核代码签名") CmdGen.Flags().IntSliceVarP(&extKeyUsage, "extKeyUsage", "e", []int{0, 1}, "扩展证书使用类型默认数字0和10表示服务器认证1表示客户端认证2表示代码签名3表示电子邮件保护4表示IPSEC终端系统5表示IPSEC隧道6表示IPSEC用户7表示时间戳8表示OCSP签名9表示Microsoft服务器网关加密10表示Netscape服务器网关加密11表示Microsoft商业代码签名12表示Microsoft内核代码签名")
CmdGen.Flags().StringVarP(&caKey, "caKey", "k", "", "CA私钥") CmdGen.Flags().StringVarP(&caKey, "caKey", "k", "", "CA私钥")
CmdGen.Flags().StringVarP(&caCert, "caCert", "C", "", "CA证书") CmdGen.Flags().StringVarP(&caCert, "caCert", "C", "", "CA证书")
CmdGen.Flags().StringVarP(&csr, "csr", "r", "", "证书请求") CmdGen.Flags().StringVarP(&csr, "csr", "r", "", "证书请求")
@ -289,21 +368,27 @@ func init() {
Cmd.AddCommand(CmdOpenssh) Cmd.AddCommand(CmdOpenssh)
CmdFastGen.Flags().BoolVarP(&promptMode, "prompt", "P", false, "是否交互模式") CmdFastGen.Flags().BoolVarP(&promptMode, "prompt", "P", false, "是否交互模式")
CmdFastGen.Flags().StringVarP(&country, "country", "c", "CN", "国家") CmdFastGen.Flags().StringVarP(&fastgen.Country, "country", "c", "", "国家")
CmdFastGen.Flags().StringVarP(&province, "province", "p", "B612", "省份") CmdFastGen.Flags().StringVarP(&fastgen.Province, "province", "p", "", "省份")
CmdFastGen.Flags().StringVarP(&city, "city", "t", "B612", "城市") CmdFastGen.Flags().StringVar(&fastgen.City, "city", "", "城市")
CmdFastGen.Flags().StringVarP(&org, "org", "o", "", "组织") CmdFastGen.Flags().StringVarP(&fastgen.Organization, "org", "o", "", "组织")
CmdFastGen.Flags().StringVarP(&orgUnit, "orgUnit", "u", "", "组织单位") CmdFastGen.Flags().StringVarP(&fastgen.OrganizationUnit, "orgUnit", "u", "", "组织单位")
CmdFastGen.Flags().StringVarP(&name, "name", "n", "Starainrt", "通用名称") CmdFastGen.Flags().StringVarP(&fastgen.CommonName, "name", "n", "", "通用名称")
CmdFastGen.Flags().StringSliceVarP(&dnsName, "dnsName", "d", nil, "dns名称") CmdFastGen.Flags().StringSliceVarP(&fastgen.Dns, "dnsName", "d", nil, "dns名称")
CmdFastGen.Flags().StringVarP(&savefolder, "savefolder", "s", "./", "保存文件夹") CmdFastGen.Flags().StringVarP(&savefolder, "savefolder", "s", "./", "保存文件夹")
CmdFastGen.Flags().IntVarP(&keyUsage, "keyUsage", "U", 0, "证书使用类型默认数字00表示数字签名和密钥加密1表示证书签名2表示CRL签名4表示密钥协商8表示数据加密") CmdFastGen.Flags().IntVarP(&fastgen.KeyUsage, "keyUsage", "U", 0, "证书使用类型默认数字00表示数字签名和密钥加密1表示证书签名2表示CRL签名4表示密钥协商8表示数据加密")
CmdFastGen.Flags().IntSliceVarP(&extKeyUsage, "extKeyUsage", "e", nil, "扩展证书使用类型默认数字00表示服务器认证1表示客户端认证2表示代码签名3表示电子邮件保护4表示IPSEC终端系统5表示IPSEC隧道6表示IPSEC用户7表示时间戳8表示OCSP签名9表示Microsoft服务器网关加密10表示Netscape服务器网关加密11表示Microsoft商业代码签名12表示Microsoft内核代码签名") CmdFastGen.Flags().IntSliceVarP(&fastgen.ExtendedKeyUsage, "extKeyUsage", "e", []int{0, 1}, "扩展证书使用类型默认数字0和10表示服务器认证1表示客户端认证2表示代码签名3表示电子邮件保护4表示IPSEC终端系统5表示IPSEC隧道6表示IPSEC用户7表示时间戳8表示OCSP签名9表示Microsoft服务器网关加密10表示Netscape服务器网关加密11表示Microsoft商业代码签名12表示Microsoft内核代码签名")
CmdFastGen.Flags().BoolVarP(&isCa, "isCa", "A", false, "是否是CA") CmdFastGen.Flags().BoolVarP(&fastgen.IsCA, "isCa", "A", false, "是否是CA")
CmdFastGen.Flags().StringVarP(&startStr, "start", "S", time.Now().Format(time.RFC3339), "开始时间,格式:2006-01-02T15:04:05Z07:00") CmdFastGen.Flags().StringVarP(&startStr, "start", "S", time.Now().Format(time.RFC3339), "开始时间,格式:2006-01-02T15:04:05Z07:00")
CmdFastGen.Flags().StringVarP(&endStr, "end", "E", time.Now().AddDate(1, 0, 0).Format(time.RFC3339), "结束时间,格式:2006-01-02T15:04:05Z07:00") CmdFastGen.Flags().StringVarP(&endStr, "end", "E", time.Now().AddDate(1, 0, 0).Format(time.RFC3339), "结束时间,格式:2006-01-02T15:04:05Z07:00")
CmdFastGen.Flags().BoolVarP(&maxPathLenZero, "maxPathLenZero", "z", false, "允许最大路径长度为0") CmdFastGen.Flags().BoolVarP(&fastgen.MaxPathLengthZero, "maxPathLenZero", "z", false, "允许最大路径长度为0")
CmdFastGen.Flags().IntVarP(&maxPathLen, "maxPathLen", "m", 0, "最大路径长度") CmdFastGen.Flags().IntVarP(&fastgen.MaxPathLength, "maxPathLen", "m", 0, "最大路径长度")
CmdFastGen.Flags().StringVarP(&caKey, "caKey", "K", "", "CA私钥可以留空")
CmdFastGen.Flags().StringVarP(&caCert, "caCert", "C", "", "CA证书可以留空")
CmdFastGen.Flags().StringVar(&caKeyPwd, "caKeyPwd", "", "CA私钥密码")
CmdFastGen.Flags().StringVarP(&fastgen.Type, "type", "t", "RSA", "证书类型支持RSA和ECDSA")
CmdFastGen.Flags().IntVarP(&fastgen.Bits, "bits", "b", 2048, "证书位数默认2048")
Cmd.AddCommand(CmdFastGen)
} }
var CmdPkcs8 = &cobra.Command{ var CmdPkcs8 = &cobra.Command{

2
go.mod
View File

@ -14,7 +14,7 @@ require (
b612.me/stario v0.0.10 b612.me/stario v0.0.10
b612.me/starlog v1.3.4 b612.me/starlog v1.3.4
b612.me/starmap v0.0.0-20240818092703-ae61140c5062 b612.me/starmap v0.0.0-20240818092703-ae61140c5062
b612.me/starnet v0.0.0-20250612085047-7a1767214960 b612.me/starnet v0.0.0-20250617043657-c1eaf4305803
b612.me/staros v1.1.8 b612.me/staros v1.1.8
b612.me/starssh v0.0.2 b612.me/starssh v0.0.2
b612.me/startext v0.0.0-20220314043758-22c6d5e5b1cd b612.me/startext v0.0.0-20220314043758-22c6d5e5b1cd

4
go.sum
View File

@ -17,8 +17,8 @@ b612.me/starlog v1.3.4 h1:XuVYo6NCij8F4TGSgtEuMhs1WkZ7HZNnYUgQ3nLTt84=
b612.me/starlog v1.3.4/go.mod h1:37GMgkWQMOAjzKs49Hf2i8bLwdXbd9QF4zKhUxFDoSk= b612.me/starlog v1.3.4/go.mod h1:37GMgkWQMOAjzKs49Hf2i8bLwdXbd9QF4zKhUxFDoSk=
b612.me/starmap v0.0.0-20240818092703-ae61140c5062 h1:ImKEWAxzBYsS/YbqdVOPdUdv6b+i/lSGpipUGueXk7w= b612.me/starmap v0.0.0-20240818092703-ae61140c5062 h1:ImKEWAxzBYsS/YbqdVOPdUdv6b+i/lSGpipUGueXk7w=
b612.me/starmap v0.0.0-20240818092703-ae61140c5062/go.mod h1:PhtO9wFrwPIHpry2CEdnVNZkrNOgfv77xrE0ZKQDkLM= b612.me/starmap v0.0.0-20240818092703-ae61140c5062/go.mod h1:PhtO9wFrwPIHpry2CEdnVNZkrNOgfv77xrE0ZKQDkLM=
b612.me/starnet v0.0.0-20250612085047-7a1767214960 h1:LRB59HvHW6D4sAKd/P2GDgmnGyx5DM8aSGnJoEifeI0= b612.me/starnet v0.0.0-20250617043657-c1eaf4305803 h1:ppTZxCYigi2wBElUVtvdDIlFjR4/tJT2O3X1ILjDHZA=
b612.me/starnet v0.0.0-20250612085047-7a1767214960/go.mod h1:6q+AXhYeXsIiKV+hZZmqAMn8S48QcdonURJyH66rbzI= b612.me/starnet v0.0.0-20250617043657-c1eaf4305803/go.mod h1:6q+AXhYeXsIiKV+hZZmqAMn8S48QcdonURJyH66rbzI=
b612.me/staros v1.1.8 h1:5Bpuf9q2nH75S2ekmieJuH3Y8LTqg/voxXCOiMAC3kk= b612.me/staros v1.1.8 h1:5Bpuf9q2nH75S2ekmieJuH3Y8LTqg/voxXCOiMAC3kk=
b612.me/staros v1.1.8/go.mod h1:4KmokjKXFW5h1hbA4aIv5O+2FptVzBubCo7IPirfqm8= b612.me/staros v1.1.8/go.mod h1:4KmokjKXFW5h1hbA4aIv5O+2FptVzBubCo7IPirfqm8=
b612.me/starssh v0.0.2 h1:cYlrXjd7ZTesdZG+7XcoLsEEMROaeWMTYonScBLnvyY= b612.me/starssh v0.0.2 h1:cYlrXjd7ZTesdZG+7XcoLsEEMROaeWMTYonScBLnvyY=

View File

@ -102,7 +102,7 @@ func autoGenCert(hostname string) *tls.Config {
return &tls.Config{Certificates: []tls.Certificate{cert}} return &tls.Config{Certificates: []tls.Certificate{cert}}
} }
if toolCa == nil { if toolCa == nil {
toolCa, toolCaKey = utils.ToolCert() toolCa, toolCaKey = utils.ToolCert("")
} }
cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{ cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{
Country: "CN", Country: "CN",

View File

@ -22,7 +22,7 @@ var speedlimit string
func init() { func init() {
Cmd.Flags().StringVarP(&hooks, "hook", "H", "", "fileget hook for modify") Cmd.Flags().StringVarP(&hooks, "hook", "H", "", "fileget hook for modify")
Cmd.Flags().StringVarP(&s.port, "port", "p", "80", "监听端口") Cmd.Flags().StringVarP(&s.port, "port", "p", "", "监听端口,http时默认80,https时默认443")
Cmd.Flags().StringVarP(&s.addr, "ip", "i", "0.0.0.0", "监听ip") Cmd.Flags().StringVarP(&s.addr, "ip", "i", "0.0.0.0", "监听ip")
Cmd.Flags().StringVarP(&s.envPath, "folder", "f", "./", "本地文件地址") Cmd.Flags().StringVarP(&s.envPath, "folder", "f", "./", "本地文件地址")
Cmd.Flags().StringVarP(&s.uploadFolder, "upload", "u", "", "文件上传文件夹路径") Cmd.Flags().StringVarP(&s.uploadFolder, "upload", "u", "", "文件上传文件夹路径")
@ -98,6 +98,13 @@ var Cmd = &cobra.Command{
if s.logpath != "" && starlog.GetWriter() == nil { if s.logpath != "" && starlog.GetWriter() == nil {
starlog.SetLogFile(s.logpath, starlog.Std, true) starlog.SetLogFile(s.logpath, starlog.Std, true)
} }
if s.port == "" {
if s.cert != "" && s.key != "" || s.autoGenCert {
s.port = "443"
} else {
s.port = "80"
}
}
if speedlimit != "" { if speedlimit != "" {
speed, err := parseSpeedString(speedlimit) speed, err := parseSpeedString(speedlimit)
if err != nil { if err != nil {

View File

@ -129,6 +129,15 @@ func (h *HttpServer) Run(ctx context.Context) error {
server := http.Server{ server := http.Server{
Addr: h.addr + ":" + h.port, Addr: h.addr + ":" + h.port,
Handler: h, Handler: h,
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
switch conn := c.(type) {
case *tls.Conn:
return context.WithValue(ctx, "istls", true)
case *starnet.Conn:
return context.WithValue(ctx, "istls", conn)
}
return context.WithValue(ctx, "istls", false)
},
} }
go func() { go func() {
select { select {
@ -209,7 +218,7 @@ func autoGenCert(hostname string) *tls.Config {
return &tls.Config{Certificates: []tls.Certificate{cert}} return &tls.Config{Certificates: []tls.Certificate{cert}}
} }
if toolCa == nil { if toolCa == nil {
toolCa, toolCaKey = utils.ToolCert() toolCa, toolCaKey = utils.ToolCert("")
} }
cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{ cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{
Country: "CN", Country: "CN",
@ -413,6 +422,7 @@ func (h *HttpServer) Listen(w http.ResponseWriter, r *http.Request) {
path = filepath.Join(path, h.indexFile) path = filepath.Join(path, h.indexFile)
} }
} }
isTls := h.isTlsStr(r)
now := time.Now() now := time.Now()
if h.SetUpload(w, r, path) { if h.SetUpload(w, r, path) {
return return
@ -421,23 +431,23 @@ func (h *HttpServer) Listen(w http.ResponseWriter, r *http.Request) {
case "OPTIONS", "HEAD": case "OPTIONS", "HEAD":
err := h.BuildHeader(w, r, fullpath) err := h.BuildHeader(w, r, fullpath)
if err != nil { if err != nil {
log.Warningf("%s %s From %s %s %.2fs %v\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds(), err) log.Warningf(isTls+"%s %s From %s %s %.2fs %v\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds(), err)
} else { } else {
log.Infof("%s %s From %s %s %.2fs \n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds()) log.Infof(isTls+"%s %s From %s %s %.2fs \n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds())
} }
case "GET": case "GET":
err := h.BuildHeader(w, r, fullpath) err := h.BuildHeader(w, r, fullpath)
if err != nil { if err != nil {
log.Warningf("GET Header Build Failed Path:%s IP:%s Err:%v\n", path, r.RemoteAddr, err) log.Warningf(isTls+"GET Header Build Failed Path:%s IP:%s Err:%v\n", path, r.RemoteAddr, err)
} }
err = h.ResponseGet(log, w, r, fullpath) err = h.ResponseGet(log, w, r, fullpath)
if err != nil { if err != nil {
log.Warningf("%s %s From %s %s %.2fs %v\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds(), err) log.Warningf(isTls+"%s %s From %s %s %.2fs %v\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds(), err)
return return
} }
log.Infof("%s %s From %s %s %.2fs\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds()) log.Infof(isTls+"%s %s From %s %s %.2fs\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds())
default: default:
log.Errorf("Invalid %s %s From %s %s %.2fs\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds()) log.Errorf(isTls+"Invalid %s %s From %s %s %.2fs\n", r.Method, path, r.RemoteAddr, ua, time.Since(now).Seconds())
return return
} }
} }
@ -703,11 +713,33 @@ func (h *HttpServer) getSleepTime() time.Duration {
} }
func (h *HttpServer) isTls(r *http.Request) bool {
if r.Context().Value("istls") != nil {
if v, ok := r.Context().Value("istls").(bool); ok {
if v {
return true
}
}
if v, ok := r.Context().Value("istls").(*starnet.Conn); ok {
return v.IsTLS()
}
}
return false
}
func (h *HttpServer) isTlsStr(r *http.Request) string {
if h.isTls(r) {
return "TLS "
}
return "PLN "
}
func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *http.Request, fullpath string) error { func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *http.Request, fullpath string) error {
if !staros.Exists(fullpath) { if !staros.Exists(fullpath) {
h.Page404(w) h.Page404(w)
return errors.New("File Not Found! 404 ERROR") return errors.New("File Not Found! 404 ERROR")
} }
isTls := h.isTlsStr(r)
var lastCount int64 var lastCount int64
var lastDate time.Time = time.Now() var lastDate time.Time = time.Now()
var currentCount int64 var currentCount int64
@ -754,7 +786,7 @@ func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *
} }
} }
} }
log.Infof("Tranfered File %s %d bytes (%s) to remote %v\n", r.URL.Path, log.Infof(isTls+"Tranfered File %s %d bytes (%s) to remote %v\n", r.URL.Path,
transferData, tani, r.RemoteAddr) transferData, tani, r.RemoteAddr)
} }
}() }()
@ -770,7 +802,7 @@ func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *
ns, err := w.Write(buf[:n]) ns, err := w.Write(buf[:n])
transferData += ns transferData += ns
if err != nil { if err != nil {
log.Errorf("Transfer File %s to Remote Failed:%v\n", fullpath, err) log.Errorf(isTls+"Transfer File %s to Remote Failed:%v\n", fullpath, err)
return err return err
} }
} }
@ -805,7 +837,7 @@ func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *
ns, err := w.Write(data) ns, err := w.Write(data)
transferData += ns transferData += ns
if err != nil { if err != nil {
log.Errorf("Transfer File %s to Remote Failed:%v\n", fullpath, err) log.Errorf(isTls+"Transfer File %s to Remote Failed:%v\n", fullpath, err)
return err return err
} }
return nil return nil
@ -816,12 +848,12 @@ func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *
ns, err := w.Write(recvData) ns, err := w.Write(recvData)
transferData += ns transferData += ns
if err != nil { if err != nil {
log.Errorf("Transfer File %s to Remote Failed:%v\n", fullpath, err) log.Errorf(isTls+"Transfer File %s to Remote Failed:%v\n", fullpath, err)
return err return err
} }
return nil return nil
} }
log.Debugf("206 transfer mode for %v %v start %v end %v\n", r.URL.Path, r.RemoteAddr, startRange, endRange) log.Debugf(isTls+"206 transfer mode for %v %v start %v end %v\n", r.URL.Path, r.RemoteAddr, startRange, endRange)
w.WriteHeader(206) w.WriteHeader(206)
fp.Seek(int64(startRange), 0) fp.Seek(int64(startRange), 0)
count := startRange count := startRange
@ -832,7 +864,7 @@ func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *
if err == io.EOF { if err == io.EOF {
break break
} }
log.Errorf("Read File %s Failed:%v\n", r.URL.Path, err) log.Errorf(isTls+"Read File %s Failed:%v\n", r.URL.Path, err)
return err return err
} }
speedControl(n) speedControl(n)
@ -840,7 +872,7 @@ func (h *HttpServer) getFile(log *starlog.StarLogger, w http.ResponseWriter, r *
ns, err := w.Write(buf[:n]) ns, err := w.Write(buf[:n])
transferData += ns transferData += ns
if err != nil { if err != nil {
log.Errorf("Transfer File %s to Remote Failed:%v\n", r.URL.Path, err) log.Errorf(isTls+"Transfer File %s to Remote Failed:%v\n", r.URL.Path, err)
return err return err
} }
} else { } else {

View File

@ -3,6 +3,7 @@ package netforward
import ( import (
"b612.me/stario" "b612.me/stario"
"b612.me/starlog" "b612.me/starlog"
"crypto/tls"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"os" "os"
"os/signal" "os/signal"
@ -30,6 +31,18 @@ func init() {
CmdNetforward.Flags().IntVarP(&f.UserTimeout, "user-timeout", "U", 0, "user timeout (milliseconds)") CmdNetforward.Flags().IntVarP(&f.UserTimeout, "user-timeout", "U", 0, "user timeout (milliseconds)")
CmdNetforward.Flags().BoolVarP(&f.IgnoreEof, "ignore-eof", "E", false, "ignore eof") CmdNetforward.Flags().BoolVarP(&f.IgnoreEof, "ignore-eof", "E", false, "ignore eof")
CmdNetforward.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "verbose mode") CmdNetforward.Flags().BoolVarP(&f.Verbose, "verbose", "v", false, "verbose mode")
CmdNetforward.Flags().BoolVarP(&f.inTls, "in-tls", "i", false, "enable input tls")
CmdNetforward.Flags().BoolVarP(&f.outTls, "out-tls", "o", false, "enable output tls")
CmdNetforward.Flags().StringVar(&f.inTlsCert, "in-cert", "", "tls cert file")
CmdNetforward.Flags().StringVar(&f.inTlsKey, "in-key", "", "tls key file")
CmdNetforward.Flags().StringVar(&f.outTlsCert, "out-cert", "", "tls cert file")
CmdNetforward.Flags().StringVar(&f.outTlsKey, "out-key", "", "tls key file")
CmdNetforward.Flags().BoolVar(&f.inTlsSkipVerify, "in-skip", false, "skip verify input tls cert")
CmdNetforward.Flags().BoolVar(&f.outTlsSkipVerify, "out-skip", false, "skip verify output tls cert")
CmdNetforward.Flags().BoolVarP(&f.inTlsAutoGen, "in-autogen", "G", false, "auto generate input tls cert")
CmdNetforward.Flags().StringSliceVar(&f.CaCerts, "ca", []string{}, "tls ca certs")
CmdNetforward.Flags().BoolVarP(&f.allowNoTls, "allow-no-tls", "A", false, "allow no tls connection")
} }
var CmdNetforward = &cobra.Command{ var CmdNetforward = &cobra.Command{
@ -41,6 +54,7 @@ var CmdNetforward = &cobra.Command{
starlog.Errorln("please enter a target uri") starlog.Errorln("please enter a target uri")
os.Exit(1) os.Exit(1)
} }
f.certCache = make(map[string]tls.Certificate)
f.RemoteURI = strings.TrimSpace(args[0]) f.RemoteURI = strings.TrimSpace(args[0])
if dialTimeout == 0 { if dialTimeout == 0 {
dialTimeout = 10000 dialTimeout = 10000

View File

@ -1,14 +1,19 @@
package netforward package netforward
import ( import (
"b612.me/apps/b612/utils"
"b612.me/stario" "b612.me/stario"
"b612.me/starlog" "b612.me/starlog"
"b612.me/starmap" "b612.me/starmap"
"b612.me/starnet"
"context" "context"
"crypto/tls"
"crypto/x509"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -41,6 +46,23 @@ type NetForward struct {
UsingKeepAlive bool UsingKeepAlive bool
Verbose bool Verbose bool
udpListener *net.UDPConn udpListener *net.UDPConn
inTls bool // 是否启用TLS
outTls bool // 是否启用TLS
inTlsCert string // TLS证书路径
inTlsKey string // TLS密钥路径
inTlsAutoGen bool // 是否自动生成TLS证书
CaCerts []string // TLS CA证书路径
outTlsKey string // TLS密钥路径
outTlsCert string // TLS证书路径
inTlsSkipVerify bool // 是否跳过TLS验证
outTlsSkipVerify bool // 是否跳过TLS验证
allowNoTls bool // 是否允许不使用TLS
certCache map[string]tls.Certificate
toolCa *x509.Certificate
toolCaKey any
caPool *x509.CertPool
outTlsCertCache tls.Certificate
} }
func (n *NetForward) UdpListener() *net.UDPConn { func (n *NetForward) UdpListener() *net.UDPConn {
@ -133,15 +155,110 @@ func (n *NetForward) Run() error {
return nil return nil
} }
func (n *NetForward) runTCP() error { func (n *NetForward) TcpListener() (net.Listener, error) {
atomic.AddInt32(&n.running, 1) if n.outTls && n.outTlsCert != "" && n.outTlsKey != "" {
defer atomic.AddInt32(&n.running, -1) cert, err := tls.LoadX509KeyPair(n.outTlsCert, n.outTlsKey)
if err != nil {
starlog.Errorln("Load X509 Key Pair Failed:", err)
return nil, err
}
n.outTlsCertCache = cert
}
cfg := net.ListenConfig{ cfg := net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
return c.Control(SetReUseAddr) return c.Control(SetReUseAddr)
}, },
} }
listen, err := cfg.Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort)) listener, err := cfg.Listen(context.Background(), "tcp", fmt.Sprintf("%s:%d", n.LocalAddr, n.LocalPort))
if !n.inTls {
return listener, err
}
var caPool *x509.CertPool
if n.inTlsAutoGen {
if n.toolCa == nil {
n.toolCa, n.toolCaKey = utils.ToolCert("")
if n.toolCa != nil {
caPool = x509.NewCertPool()
caPool.AddCert(n.toolCa)
}
}
}
if len(n.CaCerts) > 0 {
if caPool == nil {
caPool = x509.NewCertPool()
}
for _, ca := range n.CaCerts {
data, err := os.ReadFile(ca)
if err != nil {
starlog.Errorln("Read CA Cert Failed:", err)
listener.Close()
return nil, err
}
if !caPool.AppendCertsFromPEM(data) {
starlog.Errorln("Append CA Cert Failed:", ca)
listener.Close()
return nil, fmt.Errorf("append ca cert %s failed", ca)
}
}
n.caPool = caPool
}
var tlsConfig = &tls.Config{
Certificates: nil,
RootCAs: caPool,
InsecureSkipVerify: n.inTlsSkipVerify,
}
if !n.inTlsAutoGen && (n.inTlsCert != "" || n.inTlsKey != "") {
cert, err := tls.LoadX509KeyPair(n.inTlsCert, n.inTlsKey)
if err != nil {
starlog.Errorln("Load X509 Key Pair Failed:", err)
listener.Close()
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
if n.inTlsAutoGen {
return starnet.ListenWithListener(listener, tlsConfig, n.autoGenCert, n.allowNoTls)
}
return starnet.ListenWithListener(listener, tlsConfig, nil, n.allowNoTls)
}
func (n *NetForward) autoGenCert(hostname string) *tls.Config {
if cert, ok := n.certCache[hostname]; ok {
return &tls.Config{Certificates: []tls.Certificate{cert}}
}
if n.toolCa == nil {
n.toolCa, n.toolCaKey = utils.ToolCert("")
}
cert, err := utils.GenerateTlsCert(utils.GenerateCertParams{
Country: "CN",
Organization: "B612 HTTP SERVER",
OrganizationUnit: "cert@b612.me",
CommonName: hostname,
Dns: []string{hostname},
KeyUsage: int(x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign),
ExtendedKeyUsage: []int{
int(x509.ExtKeyUsageServerAuth),
int(x509.ExtKeyUsageClientAuth),
},
IsCA: false,
StartDate: time.Now().Add(-24 * time.Hour),
EndDate: time.Now().AddDate(1, 0, 0),
Type: "RSA",
Bits: 2048,
CA: n.toolCa,
CAPriv: n.toolCaKey,
})
if err != nil {
return nil
}
n.certCache[hostname] = cert
return &tls.Config{Certificates: []tls.Certificate{cert}}
}
func (n *NetForward) runTCP() error {
atomic.AddInt32(&n.running, 1)
defer atomic.AddInt32(&n.running, -1)
listen, err := n.TcpListener()
if err != nil { if err != nil {
starlog.Errorln("Listening On Tcp Failed:", err) starlog.Errorln("Listening On Tcp Failed:", err)
return err return err
@ -167,7 +284,13 @@ func (n *NetForward) runTCP() error {
log.Infof("Delay %d ms\n", n.DelayMilSec) log.Infof("Delay %d ms\n", n.DelayMilSec)
time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec)) time.Sleep(time.Millisecond * time.Duration(n.DelayMilSec))
} }
err = SetTcpInfo(conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout) switch c := conn.(type) {
case *net.TCPConn:
err = SetTcpInfo(c, n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout)
case *starnet.Conn:
err = SetTcpInfo(c.Conn.(*net.TCPConn), n.UsingKeepAlive, n.KeepAliveIdel, n.KeepAlivePeriod, n.KeepAliveCount, n.UserTimeout)
}
if err != nil { if err != nil {
log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err) log.Errorf("SetTcpInfo error for %s: %v\n", conn.RemoteAddr().String(), err)
conn.Close() conn.Close()
@ -187,6 +310,24 @@ func (n *NetForward) runTCP() error {
return return
} }
log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String()) log.Infof("TCP Connect %s <==> %s\n", conn.RemoteAddr().String(), rmt.RemoteAddr().String())
if n.outTls {
serverName, _, _ := net.SplitHostPort(n.RemoteURI)
tlsConfig := &tls.Config{
InsecureSkipVerify: n.outTlsSkipVerify,
RootCAs: n.caPool,
ServerName: serverName,
}
if n.outTlsCert != "" && n.outTlsKey != "" {
tlsConfig.Certificates = []tls.Certificate{n.outTlsCertCache}
}
rmt = tls.Client(rmt, tlsConfig)
if err := rmt.(*tls.Conn).Handshake(); err != nil {
log.Errorf("TLS Handshake Failed: %v\n", err)
conn.Close()
rmt.Close()
return
}
}
n.copy(rmt, conn) n.copy(rmt, conn)
log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI) log.Noticef("TCP Connection Closed %s <==> %s\n", conn.RemoteAddr().String(), n.RemoteURI)
conn.Close() conn.Close()

View File

@ -209,7 +209,7 @@ func run() {
return return
} }
} else { } else {
ca, cakey := utils.ToolCert() ca, cakey := utils.ToolCert("")
genCert, err := utils.GenerateTlsCert(utils.GenerateCertParams{ genCert, err := utils.GenerateTlsCert(utils.GenerateCertParams{
Country: "CN", Country: "CN",
Organization: "B612 SMTP SERVER", Organization: "B612 SMTP SERVER",

File diff suppressed because one or more lines are too long

View File

@ -77,10 +77,10 @@ func TestGenerateMiddleCA(t *testing.T) {
Locality: []string{"Asteroid B612"}, Locality: []string{"Asteroid B612"},
Organization: []string{"B612.ME"}, Organization: []string{"B612.ME"},
OrganizationalUnit: []string{"CA.B612.ME"}, OrganizationalUnit: []string{"CA.B612.ME"},
CommonName: "B612 Inter Tool CA", CommonName: "B612 Inter Tool CA 2025",
}, },
NotBefore: time.Date(2000, 01, 01, 8, 00, 00, 00, time.UTC), NotBefore: time.Date(2024, 01, 01, 8, 00, 00, 00, time.UTC),
NotAfter: time.Date(2077, 12, 31, 23, 59, 59, 00, time.UTC), NotAfter: time.Date(2026, 06, 12, 23, 59, 59, 00, time.UTC),
BasicConstraintsValid: true, BasicConstraintsValid: true,
IsCA: true, IsCA: true,
MaxPathLen: 0, MaxPathLen: 0,
@ -140,3 +140,25 @@ func LoadB612CA() (crypto.PrivateKey, *x509.Certificate, error) {
} }
return caKey, cert, nil return caKey, cert, nil
} }
func TestEncode(t *testing.T) {
crt, err := os.ReadFile("../bin/toolinter.crt")
if err != nil {
t.Fatal(err)
}
key, err := os.ReadFile("../bin/toolinter.key")
if err != nil {
t.Fatal(err)
}
aesKey := ``
encCrt, err := Encode(crt, aesKey)
if err != nil {
t.Fatal(err)
}
encKey, err := Encode(key, aesKey)
if err != nil {
t.Fatal(err)
}
fmt.Println("Encrypted Certificate:", hex.EncodeToString(encCrt))
fmt.Println("Encrypted Key:", hex.EncodeToString(encKey))
}