Merge pull request #126187 from seans3/portforward-websockets-metrics

Adds metrics to PortForward Websockets
This commit is contained in:
Kubernetes Prow Robot 2024-07-22 16:53:25 -07:00 committed by GitHub
commit 04cc0a1034
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 185 additions and 1 deletions

View File

@ -33,7 +33,7 @@ var registerMetricsOnce sync.Once
var (
// streamTranslatorRequestsTotal counts the number of requests that were handled by
// the StreamTranslatorProxy.
// the StreamTranslatorProxy (RemoteCommand subprotocol).
streamTranslatorRequestsTotal = metrics.NewCounterVec(
&metrics.CounterOpts{
Subsystem: subsystem,
@ -43,19 +43,37 @@ var (
},
[]string{statuscode},
)
// streamTunnelRequestsTotal counts the number of requests that were handled by
// the StreamTunnelProxy (PortForward subprotocol).
streamTunnelRequestsTotal = metrics.NewCounterVec(
&metrics.CounterOpts{
Subsystem: subsystem,
Name: "stream_tunnel_requests_total",
Help: "Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2",
StabilityLevel: metrics.ALPHA,
},
[]string{statuscode},
)
)
func Register() {
registerMetricsOnce.Do(func() {
legacyregistry.MustRegister(streamTranslatorRequestsTotal)
legacyregistry.MustRegister(streamTunnelRequestsTotal)
})
}
func ResetForTest() {
streamTranslatorRequestsTotal.Reset()
streamTunnelRequestsTotal.Reset()
}
// IncStreamTranslatorRequest increments the # of requests handled by the StreamTranslatorProxy.
func IncStreamTranslatorRequest(ctx context.Context, status string) {
streamTranslatorRequestsTotal.WithContext(ctx).WithLabelValues(status).Add(1)
}
// IncStreamTunnelRequest increments the # of requests handled by the StreamTunnelProxy.
func IncStreamTunnelRequest(ctx context.Context, status string) {
streamTunnelRequestsTotal.WithContext(ctx).WithLabelValues(status).Add(1)
}

View File

@ -19,10 +19,12 @@ package proxy
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
@ -34,6 +36,7 @@ import (
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
constants "k8s.io/apimachinery/pkg/util/portforward"
"k8s.io/apiserver/pkg/util/proxy/metrics"
"k8s.io/client-go/tools/portforward"
"k8s.io/klog/v2"
)
@ -61,6 +64,7 @@ func (h *TunnelingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
spdyProtocols := spdyProtocolsFromWebsocketProtocols(req)
if len(spdyProtocols) == 0 {
metrics.IncStreamTunnelRequest(req.Context(), strconv.Itoa(http.StatusBadRequest))
http.Error(w, "unable to upgrade: no tunneling spdy protocols provided", http.StatusBadRequest)
return
}
@ -326,6 +330,7 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(spdy.HeaderSpdy31)) {
klog.Errorf("unable to upgrade: missing upgrade headers in response: %#v", backendResponse.Header)
u.err = fmt.Errorf("unable to upgrade: missing upgrade headers in response")
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError))
http.Error(u.w, u.err.Error(), http.StatusInternalServerError)
return u.err
}
@ -347,16 +352,20 @@ func (u *tunnelingWebsocketUpgraderConn) InitializeWrite(backendResponse *http.R
conn, err := upgrader.Upgrade(u.w, u.req, nil)
if err != nil {
klog.Errorf("error upgrading websocket connection: %v", err)
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusInternalServerError))
u.err = err
return u.err
}
klog.V(4).Infof("websocket connection created: %s", conn.Subprotocol())
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(http.StatusSwitchingProtocols))
u.conn = portforward.NewTunnelingConnection("server", conn)
return nil
}
// anything other than an upgrade should pass through the backend response
klog.Errorf("SPDY upgrade failed: %s", backendResponse.Status)
metrics.IncStreamTunnelRequest(context.Background(), strconv.Itoa(backendResponse.StatusCode))
// try to hijack
conn, _, err = u.w.(http.Hijacker).Hijack()

View File

@ -40,11 +40,17 @@ import (
"k8s.io/apimachinery/pkg/util/proxy"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apiserver/pkg/registry/rest"
"k8s.io/apiserver/pkg/util/proxy/metrics"
restconfig "k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/component-base/metrics/legacyregistry"
"k8s.io/component-base/metrics/testutil"
)
func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
metrics.Register()
metrics.ResetForTest()
t.Cleanup(metrics.ResetForTest)
// Create fake upstream SPDY server, with channel receiving SPDY streams.
streamChan := make(chan httpstream.Stream)
defer close(streamChan)
@ -106,6 +112,157 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
}
assert.Equal(t, randomData, actual, "error validating tunneled random data")
// Validate the streamtunnel metrics; should be one 101 Switching Protocols.
metricNames := []string{"apiserver_stream_tunnel_requests_total"}
expected := `
# HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2
# TYPE apiserver_stream_tunnel_requests_total counter
apiserver_stream_tunnel_requests_total{code="101"} 1
`
if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil {
t.Fatal(err)
}
}
func TestTunnelingHandler_BadRequestWithoutProtcols(t *testing.T) {
metrics.Register()
metrics.ResetForTest()
t.Cleanup(metrics.ResetForTest)
// Create TunnelingHandler with empty upstream URL and fake transport. An error should
// be returned before the upstream proxying to SPDY occurs, so a test SPDY server is not needed.
transport, err := fakeTransport()
require.NoError(t, err)
upgradeHandler := proxy.NewUpgradeAwareHandler(&url.URL{}, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
tunnelingHandler := NewTunnelingHandler(upgradeHandler)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
tunnelingHandler.ServeHTTP(w, req)
}))
defer tunnelingServer.Close()
// Create SPDY client connection containing a TunnelingConnection by upgrading
// a request to TunnelingHandler using new portforward version 2.
tunnelingURL, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
require.NoError(t, err)
// Request without subprotocols--causing a bad request to be returned.
_, protocol, err := dialer.Dial("")
require.Error(t, err)
assert.Equal(t, "", protocol)
// Validate the streamtunnel metrics; should be one 400 failure.
metricNames := []string{"apiserver_stream_tunnel_requests_total"}
expected := `
# HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2
# TYPE apiserver_stream_tunnel_requests_total counter
apiserver_stream_tunnel_requests_total{code="400"} 1
`
if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil {
t.Fatal(err)
}
}
func TestTunnelingHandler_BadHandshakeError(t *testing.T) {
metrics.Register()
metrics.ResetForTest()
t.Cleanup(metrics.ResetForTest)
// Create fake upstream SPDY server, returning forbidden for bad handshake.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// Handshake fails.
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
require.Error(t, err, "handshake should have returned an error")
assert.True(t, strings.Contains(err.Error(), "unable to negotiate protocol"))
w.WriteHeader(http.StatusForbidden)
}))
defer spdyServer.Close()
// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then
// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer.
url, err := url.Parse(spdyServer.URL)
require.NoError(t, err)
transport, err := fakeTransport()
require.NoError(t, err)
upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
tunnelingHandler := NewTunnelingHandler(upgradeHandler)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
tunnelingHandler.ServeHTTP(w, req)
}))
defer tunnelingServer.Close()
// Create SPDY client connection containing a TunnelingConnection by upgrading
// a request to TunnelingHandler using new portforward version 2.
tunnelingURL, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
require.NoError(t, err)
// Handshake will fail, returning a 400-level response.
_, protocol, err := dialer.Dial("UNKNOWN_SUBPROTOCOL")
require.Error(t, err)
assert.Equal(t, "", protocol)
// Validate the streamtunnel metrics; should be one 400 failure.
metricNames := []string{"apiserver_stream_tunnel_requests_total"}
expected := `
# HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2
# TYPE apiserver_stream_tunnel_requests_total counter
apiserver_stream_tunnel_requests_total{code="400"} 1
`
if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil {
t.Fatal(err)
}
}
func TestTunnelingHandler_UpstreamSPDYServerErrorPropagated(t *testing.T) {
metrics.Register()
metrics.ResetForTest()
t.Cleanup(metrics.ResetForTest)
// Validate that various 500-level errors are propagated and incremented in metrics.
for statusCode, codeStr := range map[int]string{
http.StatusInternalServerError: "500",
http.StatusBadGateway: "502",
http.StatusServiceUnavailable: "503",
} {
// Create fake upstream SPDY server, which returns a 500-level error.
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
require.NoError(t, err, "handshake should have succeeded")
// Returned status code should be incremented in metrics.
w.WriteHeader(statusCode)
}))
defer spdyServer.Close()
// Create UpgradeAwareProxy handler, with url/transport pointing to upstream SPDY. Then
// create TunnelingHandler by injecting upgrade handler. Create TunnelingServer.
url, err := url.Parse(spdyServer.URL)
require.NoError(t, err)
transport, err := fakeTransport()
require.NoError(t, err)
upgradeHandler := proxy.NewUpgradeAwareHandler(url, transport, false, true, proxy.NewErrorResponder(&fakeResponder{}))
tunnelingHandler := NewTunnelingHandler(upgradeHandler)
tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
tunnelingHandler.ServeHTTP(w, req)
}))
defer tunnelingServer.Close()
// Create SPDY client connection containing a TunnelingConnection by upgrading
// a request to TunnelingHandler using new portforward version 2.
tunnelingURL, err := url.Parse(tunnelingServer.URL)
require.NoError(t, err)
dialer, err := portforward.NewSPDYOverWebsocketDialer(tunnelingURL, &restconfig.Config{Host: tunnelingURL.Host})
require.NoError(t, err)
_, protocol, err := dialer.Dial(constants.PortForwardV1Name)
require.Error(t, err)
assert.Equal(t, "", protocol)
// Validate the streamtunnel metrics are incrementing 500-level status codes.
metricNames := []string{"apiserver_stream_tunnel_requests_total"}
expected := `
# HELP apiserver_stream_tunnel_requests_total [ALPHA] Total number of requests that were handled by the StreamTunnelProxy, which processes streaming PortForward/V2
# TYPE apiserver_stream_tunnel_requests_total counter
apiserver_stream_tunnel_requests_total{code="` + codeStr + `"} 1
`
if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(expected), metricNames...); err != nil {
t.Fatal(err)
}
metrics.ResetForTest()
}
}
func TestTunnelingResponseWriter_Hijack(t *testing.T) {