Move common code to an httputil package

This commit is contained in:
Richa Banker 2024-11-21 14:55:22 -08:00
parent 40f222b620
commit ebe5bab2cb
6 changed files with 139 additions and 116 deletions

View File

@ -23,11 +23,9 @@ import (
"math/rand"
"net/http"
"sort"
"strings"
"sync"
"github.com/munnerz/goautoneg"
"k8s.io/component-base/zpages/httputil"
"k8s.io/klog/v2"
)
@ -40,8 +38,7 @@ Warning: This endpoint is not meant to be machine parseable, has no formatting c
)
var (
flagzSeparators = []string{":", ": ", "=", " "}
errUnsupportedMediaType = fmt.Errorf("media type not acceptable, must be: text/plain")
delimiters = []string{":", ": ", "=", " "}
)
type registry struct {
@ -64,8 +61,8 @@ func (reg *registry) installHandler(m mux, componentName string, flagReader Read
func (reg *registry) handleFlags(componentName string, flagReader Reader) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !acceptableMediaType(r) {
http.Error(w, errUnsupportedMediaType.Error(), http.StatusNotAcceptable)
if !httputil.AcceptableMediaType(r) {
http.Error(w, httputil.ErrUnsupportedMediaType.Error(), http.StatusNotAcceptable)
return
}
@ -76,8 +73,8 @@ func (reg *registry) handleFlags(componentName string, flagReader Reader) http.H
return
}
randomIndex := rand.Intn(len(flagzSeparators))
separator := flagzSeparators[randomIndex]
randomIndex := rand.Intn(len(delimiters))
separator := delimiters[randomIndex]
// Randomize the delimiter for printing to prevent scraping of the response.
printSortedFlags(&reg.response, flagReader.GetFlagz(), separator)
})
@ -90,29 +87,6 @@ func (reg *registry) handleFlags(componentName string, flagReader Reader) http.H
}
}
func acceptableMediaType(r *http.Request) bool {
accepts := goautoneg.ParseAccept(r.Header.Get("Accept"))
for _, accept := range accepts {
if !mediaTypeMatches(accept) {
continue
}
if len(accept.Params) == 0 {
return true
}
if len(accept.Params) == 1 {
if charset, ok := accept.Params["charset"]; ok && strings.EqualFold(charset, "utf-8") {
return true
}
}
}
return false
}
func mediaTypeMatches(a goautoneg.Accept) bool {
return (a.Type == "text" || a.Type == "*") &&
(a.SubType == "plain" || a.SubType == "*")
}
func printSortedFlags(w io.Writer, flags map[string]string, separator string) {
var sortedKeys []string
for key := range flags {

View File

@ -35,7 +35,7 @@ Warning: This endpoint is not meant to be machine parseable, has no formatting c
func TestFlagz(t *testing.T) {
componentName := "test-server"
flagzSeparators = []string{"="}
delimiters = []string{"="}
wantHeaderLines := strings.Split(fmt.Sprintf(wantTmpl, componentName), "\n")
tests := []struct {
name string

View File

@ -0,0 +1,54 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package httputil
import (
"fmt"
"net/http"
"strings"
"github.com/munnerz/goautoneg"
)
// ErrUnsupportedMediaType is the error returned when the request's
// Accept header does not contain "text/plain".
var ErrUnsupportedMediaType = fmt.Errorf("media type not acceptable, must be: text/plain")
// AcceptableMediaType checks if the request's Accept header contains
// a supported media type with optional "charset=utf-8" parameter.
func AcceptableMediaType(r *http.Request) bool {
accepts := goautoneg.ParseAccept(r.Header.Get("Accept"))
for _, accept := range accepts {
if !mediaTypeMatches(accept) {
continue
}
if len(accept.Params) == 0 {
return true
}
if len(accept.Params) == 1 {
if charset, ok := accept.Params["charset"]; ok && strings.EqualFold(charset, "utf-8") {
return true
}
}
}
return false
}
func mediaTypeMatches(a goautoneg.Accept) bool {
return (a.Type == "text" || a.Type == "*") &&
(a.SubType == "plain" || a.SubType == "*")
}

View File

@ -0,0 +1,74 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package httputil
import (
"net/http"
"testing"
)
func TestAcceptableMediaTypes(t *testing.T) {
tests := []struct {
name string
reqHeader string
want bool
}{
{
name: "valid text/plain header",
reqHeader: "text/plain",
want: true,
},
{
name: "valid text/* header",
reqHeader: "text/*",
want: true,
},
{
name: "valid */plain header",
reqHeader: "*/plain",
want: true,
},
{
name: "valid accept args",
reqHeader: "text/plain; charset=utf-8",
want: true,
},
{
name: "invalid text/foo header",
reqHeader: "text/foo",
want: false,
},
{
name: "invalid text/plain params",
reqHeader: "text/plain; foo=bar",
want: false,
},
}
for _, tt := range tests {
req, err := http.NewRequest(http.MethodGet, "http://example.com/statusz", nil)
if err != nil {
t.Fatalf("Unexpected error while creating request: %v", err)
}
req.Header.Set("Accept", tt.reqHeader)
got := AcceptableMediaType(req)
if got != tt.want {
t.Errorf("Unexpected response from AcceptableMediaType(), want %v, got = %v", tt.want, got)
}
}
}

View File

@ -22,17 +22,14 @@ import (
"html/template"
"math/rand"
"net/http"
"strings"
"time"
"github.com/munnerz/goautoneg"
"k8s.io/component-base/zpages/httputil"
"k8s.io/klog/v2"
)
var (
delimiters = []string{":", ": ", "=", " "}
errUnsupportedMediaType = fmt.Errorf("media type not acceptable, must be: text/plain")
delimiters = []string{":", ": ", "=", " "}
)
const (
@ -88,8 +85,8 @@ func initializeTemplates() (*template.Template, error) {
func handleStatusz(componentName string, dataTmpl *template.Template, reg statuszRegistry) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !acceptableMediaType(r) {
http.Error(w, errUnsupportedMediaType.Error(), http.StatusNotAcceptable)
if !httputil.AcceptableMediaType(r) {
http.Error(w, httputil.ErrUnsupportedMediaType.Error(), http.StatusNotAcceptable)
return
}
@ -106,30 +103,6 @@ func handleStatusz(componentName string, dataTmpl *template.Template, reg status
}
}
// TODO(richabanker) : Move this to a common place to be reused for all zpages.
func acceptableMediaType(r *http.Request) bool {
accepts := goautoneg.ParseAccept(r.Header.Get("Accept"))
for _, accept := range accepts {
if !mediaTypeMatches(accept) {
continue
}
if len(accept.Params) == 0 {
return true
}
if len(accept.Params) == 1 {
if charset, ok := accept.Params["charset"]; ok && strings.EqualFold(charset, "utf-8") {
return true
}
}
}
return false
}
func mediaTypeMatches(a goautoneg.Accept) bool {
return (a.Type == "text" || a.Type == "*") &&
(a.SubType == "plain" || a.SubType == "*")
}
func populateStatuszData(tmpl *template.Template, reg statuszRegistry) (string, error) {
if tmpl == nil {
return "", fmt.Errorf("received nil template")

View File

@ -152,58 +152,6 @@ func TestStatusz(t *testing.T) {
}
}
func TestAcceptableMediaTypes(t *testing.T) {
tests := []struct {
name string
reqHeader string
want bool
}{
{
name: "valid text/plain header",
reqHeader: "text/plain",
want: true,
},
{
name: "valid text/* header",
reqHeader: "text/*",
want: true,
},
{
name: "valid */plain header",
reqHeader: "*/plain",
want: true,
},
{
name: "valid accept args",
reqHeader: "text/plain; charset=utf-8",
want: true,
},
{
name: "invalid text/foo header",
reqHeader: "text/foo",
want: false,
},
{
name: "invalid text/plain params",
reqHeader: "text/plain; foo=bar",
want: false,
},
}
for _, tt := range tests {
req, err := http.NewRequest(http.MethodGet, "http://example.com/statusz", nil)
if err != nil {
t.Fatalf("Unexpected error while creating request: %v", err)
}
req.Header.Set("Accept", tt.reqHeader)
got := acceptableMediaType(req)
if got != tt.want {
t.Errorf("Unexpected response from acceptableMediaType(), want %v, got = %v", tt.want, got)
}
}
}
func parseVersion(t *testing.T, v string) *version.Version {
parsed, err := version.ParseMajorMinor(v)
if err != nil {