2020-02-20 23:06:43 +00:00
|
|
|
package urlbuilder
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"net/http"
|
|
|
|
"net/url"
|
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
|
|
|
|
func RedirectRewrite(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
|
|
prefix := req.Header.Get(PrefixHeader)
|
|
|
|
if prefix == "" {
|
|
|
|
next.ServeHTTP(rw, req)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
r := &redirector{
|
|
|
|
ResponseWriter: rw,
|
|
|
|
prefix: prefix,
|
|
|
|
}
|
2020-02-28 16:30:25 +00:00
|
|
|
if h, ok := rw.(http.Hijacker); ok {
|
|
|
|
r.Hijacker = h
|
|
|
|
}
|
2020-02-20 23:06:43 +00:00
|
|
|
next.ServeHTTP(r, req)
|
|
|
|
r.Close()
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
// redirectBodyEscaper performs the same five replacements as html.EscapeString
// (the escaping net/http applies when http.Redirect embeds the target URL in
// its HTML body). WriteHeader uses it to locate that embedded link.
var redirectBodyEscaper = strings.NewReplacer(
	"&", "&amp;",
	"'", "&#39;",
	"<", "&lt;",
	">", "&gt;",
	`"`, "&#34;",
)

// redirector wraps an http.ResponseWriter so that 301/302 redirects whose
// Location path is not under prefix are rewritten to include it. While such a
// rewrite is pending, the response body is buffered in tempBuffer. Close then
// writes it out with every occurrence of from replaced by to.
//
// NOTE(review): Hijacker is an embedded interface, so *redirector always
// satisfies http.Hijacker. If the field is left nil, calling Hijack panics.
type redirector struct {
	http.ResponseWriter
	http.Hijacker

	// prefix is prepended to redirect paths that are not already under it.
	prefix string
	// from/to is the substitution Close applies to the buffered body.
	from, to string
	// tempBuffer holds the body while a rewrite is pending. When it is nil,
	// writes go straight to the wrapped ResponseWriter.
	tempBuffer *bytes.Buffer
}

// Write buffers content while a redirect rewrite is pending and otherwise
// forwards it unchanged to the wrapped ResponseWriter.
func (r *redirector) Write(content []byte) (int, error) {
	if r.tempBuffer == nil {
		return r.ResponseWriter.Write(content)
	}
	return r.tempBuffer.Write(content)
}

// Close writes any buffered body to the wrapped ResponseWriter, applying the
// from→to substitution first. It does nothing when no body is buffered, and a
// second call does not write the body again. It returns the error from the
// underlying Write.
func (r *redirector) Close() error {
	if r.tempBuffer == nil {
		return nil
	}
	content := r.tempBuffer.Bytes()
	// Drop the buffer first so later Writes pass through and a repeated Close
	// is a no-op.
	r.tempBuffer = nil
	// An empty from would make bytes.Replace insert `to` between every byte.
	// Without a usable substitution, write the body unchanged instead of
	// discarding it.
	if r.from != "" && r.to != "" {
		content = bytes.Replace(content, []byte(r.from), []byte(r.to), -1)
	}
	_, err := r.ResponseWriter.Write(content)
	return err
}

// WriteHeader sends statusCode to the wrapped ResponseWriter, rewriting
// redirects first. For a 301 or 302 response with a Location header, it does
// three things:
//   - It prepends r.prefix to the Location path, unless the path is already
//     under that prefix on a path-segment boundary.
//   - It arranges for Close to replace the quoted, HTML-escaped old Location
//     in the buffered body with the new one.
//   - It sends the response as 302 Found.
//
// Other statuses, and Locations that fail to parse, are forwarded unchanged.
func (r *redirector) WriteHeader(statusCode int) {
	// The closure reads statusCode when it runs, so any reassignment below is
	// what gets sent, not the value passed in.
	defer func() {
		r.ResponseWriter.WriteHeader(statusCode)
	}()

	if statusCode != http.StatusMovedPermanently && statusCode != http.StatusFound {
		return
	}

	location := r.Header().Get("Location")
	if location == "" {
		return
	}

	u, err := url.Parse(location)
	if err != nil {
		// An unparseable Location cannot be rewritten. Pass it through as-is
		// rather than dereferencing a nil *url.URL.
		return
	}

	// Require a segment boundary so that prefix "/api" does not treat
	// "/apiary" as already prefixed.
	prefixed := r.prefix == "" ||
		(strings.HasPrefix(u.Path, r.prefix) &&
			(len(u.Path) == len(r.prefix) ||
				strings.HasSuffix(r.prefix, "/") ||
				u.Path[len(r.prefix)] == '/'))

	if !prefixed {
		u.Path = r.prefix + u.Path
		if u.RawPath != "" {
			// Keep an explicit encoding (e.g. %2F) consistent with the new Path.
			u.RawPath = r.prefix + u.RawPath
		}
		newLocation := u.String()
		r.Header().Set("Location", newLocation)

		// http.Redirect writes the target as href="<html-escaped URL>".
		// Matching that quoted form, not every occurrence of the bare path,
		// keeps short paths like "/" from also rewriting unrelated text
		// such as "</a>".
		r.from = `"` + redirectBodyEscaper.Replace(location) + `"`
		r.to = `"` + redirectBodyEscaper.Replace(newLocation) + `"`
		r.tempBuffer = &bytes.Buffer{}
		// Close changes the body's length, so a declared length would be wrong.
		r.Header().Del("Content-Length")
	}

	statusCode = http.StatusFound
}
|