From 9fc353211f9c69d6818c9e496ecf19279afa1ac2 Mon Sep 17 00:00:00 2001 From: starainrt Date: Sun, 12 May 2024 16:32:13 +0800 Subject: [PATCH] add ip filter for httpreverseproxy --- go.mod | 2 +- httpreverse/cfg.ini | 7 +- httpreverse/cmd.go | 2 +- httpreverse/reverse.go | 111 +++-- httpreverse/reverse_test.go | 9 + httpreverse/rp/reverseproxy.go | 852 +++++++++++++++++++++++++++++++++ httpreverse/service.go | 102 ++-- 7 files changed, 1012 insertions(+), 73 deletions(-) create mode 100644 httpreverse/rp/reverseproxy.go diff --git a/go.mod b/go.mod index 6648ec5..87787cd 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/spf13/cobra v1.8.0 github.com/things-go/go-socks5 v0.0.5 golang.org/x/crypto v0.21.0 + golang.org/x/net v0.21.0 software.sslmate.com/src/go-pkcs12 v0.4.0 ) @@ -39,7 +40,6 @@ require ( github.com/spf13/pflag v1.0.5 // indirect golang.org/x/image v0.6.0 // indirect golang.org/x/mod v0.14.0 // indirect - golang.org/x/net v0.21.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/httpreverse/cfg.ini b/httpreverse/cfg.ini index 54a0d27..9052712 100644 --- a/httpreverse/cfg.ini +++ b/httpreverse/cfg.ini @@ -12,4 +12,9 @@ authuser=b612 authpasswd=b612 whiteip= blackip= -wanringpage= \ No newline at end of file +wanringpage= +ipfiltermode=3 +filterxforward= +filterremoteaddr= +filtermustkey= +filterfile= \ No newline at end of file diff --git a/httpreverse/cmd.go b/httpreverse/cmd.go index bac1e3f..f757e2f 100644 --- a/httpreverse/cmd.go +++ b/httpreverse/cmd.go @@ -75,7 +75,7 @@ var Cmd = &cobra.Command{ SkipSSLVerify: skipsslverify, Key: key, Cert: cert, - XForwardMode: 1, + IPFilterMode: 1, } go func() { sig := make(chan os.Signal) diff --git a/httpreverse/reverse.go b/httpreverse/reverse.go index 075c781..1a79462 100644 --- a/httpreverse/reverse.go +++ b/httpreverse/reverse.go @@ -1,35 +1,46 @@ package httpreverse import ( + "b612.me/apps/b612/httpreverse/rp" + "b612.me/starlog" "b612.me/staros/sysconf" + "bufio" "errors" + "io" "io/ioutil" + "net" "net/http" - "net/http/httputil" "net/url" + "os" "strings" "sync" ) type ReverseConfig struct { - Name string - Addr string - ReverseURL map[string]*url.URL - Port int - UsingSSL bool - Key string - Cert string - Host string - SkipSSLVerify bool - InHeader [][2]string - OutHeader [][2]string - Cookie [][3]string //[3]string should contains path::key::value - ReplaceList [][2]string - ReplaceOnce bool - proxy map[string]*httputil.ReverseProxy - XForwardMode int //0=off 1=useremote 2=add - httpmux http.ServeMux - httpserver http.Server + Name string + Addr string + ReverseURL map[string]*url.URL + Port int + UsingSSL bool + Key string + Cert string + Host string + SkipSSLVerify bool + InHeader [][2]string + OutHeader [][2]string + Cookie [][3]string //[3]string should contains path::key::value + ReplaceList [][2]string + ReplaceOnce bool + proxy map[string]*rp.ReverseProxy + IPFilterMode int //0=off 1=useremote 2=add 3=filter + FilterXForward bool + FilterRemoteAddr bool + FilterMustKey string + FilterSetKey string + FilterFile string + httpmux http.ServeMux + httpserver http.Server + CIDR []*net.IPNet basicAuthUser string basicAuthPwd string @@ -53,18 +64,54 @@ func Parse(path string) (HttpReverseServer, error) { } for _, v := range ini.Data { var ins = ReverseConfig{ - Name: v.Name, - Host: v.Get("host"), - Addr: v.Get("addr"), - Port: v.Int("port"), - UsingSSL: v.Bool("enablessl"), - Key: v.Get("key"), - Cert: v.Get("cert"), - ReplaceOnce: v.Bool("replaceonce"), - XForwardMode: v.Int("xforwardmode"), - basicAuthUser: v.Get("authuser"), - basicAuthPwd: v.Get("authpasswd"), - warningpage: v.Get("warnpage"), + Name: v.Name, + Host: v.Get("host"), + Addr: v.Get("addr"), + Port: v.Int("port"), + UsingSSL: v.Bool("enablessl"), + Key: v.Get("key"), + Cert: v.Get("cert"), + ReplaceOnce: v.Bool("replaceonce"), + IPFilterMode: v.Int("ipfiltermode"), + FilterXForward: v.Bool("filterxforward"), + FilterRemoteAddr: v.Bool("filterremoteaddr"), + FilterMustKey: v.Get("filtermustkey"), + FilterSetKey: v.Get("filtersetkey"), + FilterFile: v.Get("filterfile"), + basicAuthUser: v.Get("authuser"), + basicAuthPwd: v.Get("authpasswd"), + warningpage: v.Get("warnpage"), + } + if ins.IPFilterMode == 3 && ins.FilterFile != "" { + starlog.Infoln("IP Filter Mode 3, Load IP Filter File", ins.FilterFile) + f, err := os.Open(ins.FilterFile) + if err != nil { + return res, err + } + buf := bufio.NewReader(f) + count := 0 + for { + line, err := buf.ReadString('\n') + if err != nil { + if err == io.EOF { + f.Close() + break + } + f.Close() + return res, err + } + line = strings.TrimSpace(line) + if !strings.Contains(line, "/") { + line += "/32" //todo:区分IPV6 + } + _, cidr, err := net.ParseCIDR(line) + if err != nil { + return res, err + } + ins.CIDR = append(ins.CIDR, cidr) + count++ + } + starlog.Infoln("Load", count, "CIDR") } if ins.warningpage != "" { data, err := ioutil.ReadFile(ins.warningpage) @@ -73,7 +120,7 @@ func Parse(path string) (HttpReverseServer, error) { } ins.warnpagedata = data } - ins.proxy = make(map[string]*httputil.ReverseProxy) + ins.proxy = make(map[string]*rp.ReverseProxy) ins.ReverseURL = make(map[string]*url.URL) for _, reverse := range v.GetAll("reverse") { kv := strings.SplitN(reverse, "::", 2) diff --git a/httpreverse/reverse_test.go b/httpreverse/reverse_test.go index a72f561..02e9c17 100644 --- a/httpreverse/reverse_test.go +++ b/httpreverse/reverse_test.go @@ -2,9 +2,18 @@ package httpreverse import ( "fmt" + "net" "testing" ) +func TestCIDR(t *testing.T) { + _, c, err := net.ParseCIDR("108.162.192.0/18") + if err != nil { + t.Fatal(err) + } + fmt.Println(c.Contains(net.ParseIP("108.162.245.124"))) +} + func TestReverseParse(t *testing.T) { data, err := Parse("./cfg.ini") if err != nil { diff --git a/httpreverse/rp/reverseproxy.go b/httpreverse/rp/reverseproxy.go new file mode 100644 index 0000000..02f0172 --- /dev/null +++ b/httpreverse/rp/reverseproxy.go @@ -0,0 +1,852 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package rp + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httptrace" + "net/textproto" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +func lower(b byte) byte { + if 'A' <= b && b <= 'Z' { + return b + ('a' - 'A') + } + return b +} + +func EqualFold(s, t string) bool { + if len(s) != len(t) { + return false + } + for i := 0; i < len(s); i++ { + if lower(s[i]) != lower(t[i]) { + return false + } + } + return true +} + +func IsPrint(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < ' ' || s[i] > '~' { + return false + } + } + return true +} + +// A ProxyRequest contains a request to be rewritten by a ReverseProxy. +type ProxyRequest struct { + // In is the request received by the proxy. + // The Rewrite function must not modify In. + In *http.Request + + // Out is the request which will be sent by the proxy. + // The Rewrite function may modify or replace this request. + // Hop-by-hop headers are removed from this request + // before Rewrite is called. + Out *http.Request +} + +// SetURL routes the outbound request to the scheme, host, and base path +// provided in target. If the target's path is "/base" and the incoming +// request was for "/dir", the target request will be for "/base/dir". +// +// SetURL rewrites the outbound Host header to match the target's host. +// To preserve the inbound request's Host header (the default behavior +// of NewSingleHostReverseProxy): +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.SetURL(url) +// r.Out.Host = r.In.Host +// } +func (r *ProxyRequest) SetURL(target *url.URL) { + rewriteRequestURL(r.Out, target) + r.Out.Host = "" +} + +// SetXForwarded sets the X-Forwarded-For, X-Forwarded-Host, and +// X-Forwarded-Proto headers of the outbound request. +// +// - The X-Forwarded-For header is set to the client IP address. +// - The X-Forwarded-Host header is set to the host name requested +// by the client. +// - The X-Forwarded-Proto header is set to "http" or "https", depending +// on whether the inbound request was made on a TLS-enabled connection. +// +// If the outbound request contains an existing X-Forwarded-For header, +// SetXForwarded appends the client IP address to it. To append to the +// inbound request's X-Forwarded-For header (the default behavior of +// ReverseProxy when using a Director function), copy the header +// from the inbound request before calling SetXForwarded: +// +// rewriteFunc := func(r *httputil.ProxyRequest) { +// r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] +// r.SetXForwarded() +// } +func (r *ProxyRequest) SetXForwarded() { + clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr) + if err == nil { + prior := r.Out.Header["X-Forwarded-For"] + if len(prior) > 0 { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + r.Out.Header.Set("X-Forwarded-For", clientIP) + } else { + r.Out.Header.Del("X-Forwarded-For") + } + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + if r.In.TLS == nil { + r.Out.Header.Set("X-Forwarded-Proto", "http") + } else { + r.Out.Header.Set("X-Forwarded-Proto", "https") + } +} + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +// +// 1xx responses are forwarded to the client if the underlying +// transport supports ClientTrace.Got1xxResponse. +type ReverseProxy struct { + // Rewrite must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Rewrite must not access the provided ProxyRequest + // or its contents after returning. + // + // The Forwarded, X-Forwarded, X-Forwarded-Host, + // and X-Forwarded-Proto headers are removed from the + // outbound request before Rewrite is called. See also + // the ProxyRequest.SetXForwarded method. + // + // Unparsable query parameters are removed from the + // outbound request before Rewrite is called. + // The Rewrite function may copy the inbound URL's + // RawQuery to the outbound URL to preserve the original + // parameter string. Note that this can lead to security + // issues if the proxy's interpretation of query parameters + // does not match that of the downstream server. + // + // At most one of Rewrite or Director may be set. + Rewrite func(*ProxyRequest) + + // Director is a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + // + // By default, the X-Forwarded-For header is set to the + // value of the client IP address. If an X-Forwarded-For + // header already exists, the client IP is appended to the + // existing values. As a special case, if the header + // exists in the Request.Header map but has a nil value + // (such as when set by the Director func), the X-Forwarded-For + // header is not modified. + // + // To prevent IP spoofing, be sure to delete any pre-existing + // X-Forwarded-For header coming from the client or + // an untrusted proxy. + // + // Hop-by-hop headers are removed from the request after + // Director returns, which can remove headers added by + // Director. Use a Rewrite function instead to ensure + // modifications to the request are preserved. + // + // Unparsable query parameters are removed from the outbound + // request if Request.Form is set after Director returns. + // + // At most one of Rewrite or Director may be set. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + // A negative value means to flush immediately + // after each write to the client. + // The FlushInterval is ignored when ReverseProxy + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool + + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called without any call to ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. + ModifyResponse func(*http.Response) error + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by io.CopyBuffer. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// +// NewSingleHostReverseProxy does not rewrite the Host header. +// +// To customize the ReverseProxy behavior beyond what +// NewSingleHostReverseProxy provides, use ReverseProxy directly +// with a Rewrite function. The ProxyRequest SetURL method +// may be used to route the outbound request. (Note that SetURL, +// unlike NewSingleHostReverseProxy, rewrites the Host header +// of the outbound request by default.) +// +// proxy := &ReverseProxy{ +// Rewrite: func(r *ProxyRequest) { +// r.SetURL(target) +// r.Out.Host = r.In.Host // if desired +// } +// } +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + director := func(req *http.Request) { + rewriteRequestURL(req, target) + } + return &ReverseProxy{Director: director} +} + +func rewriteRequestURL(req *http.Request, target *url.URL) { + targetQuery := target.RawQuery + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + if err := p.ModifyResponse(res); err != nil { + res.Body.Close() + p.getErrorHandler()(rw, req, err) + return false + } + return true +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx := req.Context() + if ctx.Done() != nil { + // CloseNotifier predates context.Context, and has been + // entirely superseded by it. If the request contains + // a Context that carries a cancellation signal, don't + // bother spinning up a goroutine to watch the CloseNotify + // channel (if any). + // + // If the request Context has a nil Done channel (which + // means it is either context.Background, or a custom + // Context implementation with no cancellation signal), + // then consult the CloseNotifier if available. + } else if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.Clone(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + if outreq.Body != nil { + // Reading from the request body after returning from a handler is not + // allowed, and the RoundTrip goroutine that reads the Body can outlive + // this handler. This can lead to a crash if the handler panics (see + // Issue 46866). Although calling Close doesn't guarantee there isn't + // any Read in flight after the handle returns, in practice it's safe to + // read after closing it. + defer outreq.Body.Close() + } + if outreq.Header == nil { + outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate + } + + if (p.Director != nil) == (p.Rewrite != nil) { + p.getErrorHandler()(rw, req, errors.New("ReverseProxy must have exactly one of Director or Rewrite set")) + return + } + + if p.Director != nil { + p.Director(outreq) + if outreq.Form != nil { + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + } + } + outreq.Close = false + + reqUpType := upgradeType(outreq.Header) + if !IsPrint(reqUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType)) + return + } + removeHopByHopHeaders(outreq.Header) + + // Issue 21096: tell backend applications that care about trailer support + // that we support trailers. (We do, but we don't go out of our way to + // advertise that unless the incoming client request thought it was worth + // mentioning.) Note that we look at req.Header, not outreq.Header, since + // the latter has passed through removeHopByHopHeaders. + if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") { + outreq.Header.Set("Te", "trailers") + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + if p.Rewrite != nil { + // Strip client-provided forwarding headers. + // The Rewrite func may use SetXForwarded to set new values + // for these or copy the previous values from the inbound request. + outreq.Header.Del("Forwarded") + outreq.Header.Del("X-Forwarded-For") + outreq.Header.Del("X-Forwarded-Host") + outreq.Header.Del("X-Forwarded-Proto") + + // Remove unparsable query parameters from the outbound request. + outreq.URL.RawQuery = cleanQueryParams(outreq.URL.RawQuery) + + pr := &ProxyRequest{ + In: req, + Out: outreq, + } + p.Rewrite(pr) + outreq = pr.Out + } + + if _, ok := outreq.Header["User-Agent"]; !ok { + // If the outbound request doesn't have a User-Agent header set, + // don't send the default Go HTTP client User-Agent. + outreq.Header.Set("User-Agent", "") + } + + trace := &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + h := rw.Header() + copyHeader(h, http.Header(header)) + rw.WriteHeader(code) + + // Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses + for k := range h { + delete(h, k) + } + + return nil + }, + } + outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.getErrorHandler()(rw, outreq, err) + return + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, outreq) { + return + } + p.handleUpgradeResponse(rw, outreq, res) + return + } + + removeHopByHopHeaders(res.Header) + + if !p.modifyResponse(rw, res, outreq) { + return + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + err = p.copyResponse(rw, res.Body, p.flushInterval(res)) + if err != nil { + defer res.Body.Close() + // Since we're streaming the response, if we run into an error all we can do + // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // on read error while copying body. + if !shouldPanicOnCopyError(req) { + p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return + } + panic(http.ErrAbortHandler) + } + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +var inOurTests bool // whether we're in our own tests + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +func shouldPanicOnCopyError(req *http.Request) bool { + if inOurTests { + // Our tests know to handle this panic. + return true + } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +// removeHopByHopHeaders removes hop-by-hop headers. +func removeHopByHopHeaders(h http.Header) { + // RFC 7230, section 6.1: Remove headers listed in the "Connection" header. + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = textproto.TrimString(sf); sf != "" { + h.Del(sf) + } + } + } + // RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers. + // This behavior is superseded by the RFC 7230 Connection header, but + // preserve it for backwards compatibility. + for _, f := range hopHeaders { + h.Del(f) + } +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" { + return -1 // negative means immediately + } + + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + + return p.FlushInterval +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { + if flushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: flushInterval, + } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + dst = mlw + } + } + + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + defer p.BufferPool.Put(buf) + } + _, err := p.copyBuffer(dst, src, buf) + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +func (p *ReverseProxy) logf(format string, args ...any) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return h.Get("Upgrade") +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if !IsPrint(resUpType) { // We know reqUpType is ASCII, it's checked by the caller. + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch to invalid protocol %q", resUpType)) + } + if !EqualFold(reqUpType, resUpType) { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + hj, ok := rw.(http.Hijacker) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancellation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + + defer close(backConnCloseCh) + + conn, brw, err := hj.Hijack() + if err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} + +func cleanQueryParams(s string) string { + reencode := func(s string) string { + v, _ := url.ParseQuery(s) + return v.Encode() + } + for i := 0; i < len(s); { + switch s[i] { + case ';': + return reencode(s) + case '%': + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + return reencode(s) + } + i += 3 + default: + i++ + } + } + return s +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} diff --git a/httpreverse/service.go b/httpreverse/service.go index 70b152b..1d58140 100644 --- a/httpreverse/service.go +++ b/httpreverse/service.go @@ -1,6 +1,7 @@ package httpreverse import ( + "b612.me/apps/b612/httpreverse/rp" "b612.me/starlog" "bytes" "context" @@ -10,7 +11,6 @@ import ( "io/ioutil" "net" "net/http" - "net/http/httputil" "net/url" "strconv" "strings" @@ -111,31 +111,27 @@ func (h *ReverseConfig) dialTLS(ctx context.Context, network, addr string) (net. } func (h *ReverseConfig) init() error { - h.proxy = make(map[string]*httputil.ReverseProxy) + h.proxy = make(map[string]*rp.ReverseProxy) for key, val := range h.ReverseURL { - h.proxy[key] = &httputil.ReverseProxy{ + h.proxy[key] = &rp.ReverseProxy{ Transport: &http.Transport{DialTLSContext: h.dialTLS}, - Director: func(req *http.Request) { - targetQuery := val.RawQuery - req.URL.Scheme = val.Scheme - if h.Host == "" { - req.Host = val.Host - } else { - req.Host = h.Host - } - req.URL.Host = val.Host - req.URL.Path, req.URL.RawPath = joinURLPath(val, req.URL) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery - } - }, } h.proxy[key].ModifyResponse = h.ModifyResponse() - originalDirector := h.proxy[key].Director h.proxy[key].Director = func(req *http.Request) { - originalDirector(req) + targetQuery := val.RawQuery + req.URL.Scheme = val.Scheme + if h.Host == "" { + req.Host = val.Host + } else { + req.Host = h.Host + } + req.URL.Host = val.Host + req.URL.Path, req.URL.RawPath = joinURLPath(val, req.URL, key) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } h.ModifyRequest(req, val) } } @@ -181,13 +177,54 @@ func (h *ReverseConfig) ModifyResponse() func(*http.Response) error { } } +func (h *ReverseConfig) isInCIDR(ip string) bool { + nip := net.ParseIP(strings.TrimSpace(ip)) + if nip == nil { + return false + } + for _, c := range h.CIDR { + if c.Contains(nip) { + return true + } + } + return false +} + func (h *ReverseConfig) ModifyRequest(req *http.Request, remote *url.URL) { - if h.XForwardMode == 1 { + switch h.IPFilterMode { + case 1: req.Header.Set("X-Forwarded-For", strings.Split(req.RemoteAddr, ":")[0]) - } else if h.XForwardMode == 2 { + case 2: xforward := strings.Split(strings.TrimSpace(req.Header.Get("X-Forwarded-For")), ",") xforward = append(xforward, strings.Split(req.RemoteAddr, ":")[0]) req.Header.Set("X-Forwarded-For", strings.Join(xforward, ", ")) + case 3: + var lastForwardIP string + var xforward []string + if h.FilterMustKey != "" && req.Header.Get(h.FilterMustKey) != "" { + lastForwardIP = req.Header.Get(h.FilterMustKey) + xforward = []string{lastForwardIP} + } else { + for _, ip := range append(strings.Split(strings.TrimSpace(req.Header.Get("X-Forwarded-For")), ","), strings.Split(req.RemoteAddr, ":")[0]) { + ip = strings.TrimSpace(ip) + if !h.isInCIDR(ip) { + xforward = append(xforward, ip) + lastForwardIP = ip + } + } + } + if lastForwardIP == "" { + lastForwardIP = strings.Split(req.RemoteAddr, ":")[0] + } + if h.FilterXForward { + req.Header.Set("X-Forwarded-For", strings.Join(xforward, ", ")) + } + if h.FilterRemoteAddr { + req.Header.Set("X-Real-IP", lastForwardIP) + } + if h.FilterSetKey != "" { + req.Header.Set(h.FilterSetKey, lastForwardIP) + } } for _, v := range h.Cookie { req.AddCookie(&http.Cookie{ @@ -196,20 +233,6 @@ func (h *ReverseConfig) ModifyRequest(req *http.Request, remote *url.URL) { Path: v[0], }) } - host := h.Host - if host == "" { - host = remote.Host - } - targetQuery := remote.RawQuery - req.URL.Scheme = remote.Scheme - req.URL.Host = remote.Host - req.Host = host - req.URL.Path, req.URL.RawPath = joinURLPath(remote, req.URL) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery - } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery - } for _, v := range h.InHeader { req.Header.Set(v[0], v[1]) } @@ -281,12 +304,15 @@ func (h *ReverseConfig) filter(w http.ResponseWriter, r *http.Request) bool { return true } -func joinURLPath(a, b *url.URL) (path, rawpath string) { +func joinURLPath(a, b *url.URL, hpath string) (path, rawpath string) { + b.Path = strings.TrimPrefix(b.Path, hpath) + b.RawPath = strings.TrimPrefix(b.RawPath, hpath) if a.RawPath == "" && b.RawPath == "" { return singleJoiningSlash(a.Path, b.Path), "" } // Same as singleJoiningSlash, but uses EscapedPath to determine // whether a slash should be added + apath := a.EscapedPath() bpath := b.EscapedPath()