diff --git a/discovery/cached/disk/round_tripper.go b/discovery/cached/disk/round_tripper.go index bda2e5cf..f3a4b294 100644 --- a/discovery/cached/disk/round_tripper.go +++ b/discovery/cached/disk/round_tripper.go @@ -17,12 +17,14 @@ limitations under the License. package disk import ( + "bytes" + "crypto/sha256" + "fmt" "net/http" "os" "path/filepath" "github.com/gregjones/httpcache" - "github.com/gregjones/httpcache/diskcache" "github.com/peterbourgon/diskv" "k8s.io/klog/v2" ) @@ -41,7 +43,7 @@ func newCacheRoundTripper(cacheDir string, rt http.RoundTripper) http.RoundTripp BasePath: cacheDir, TempDir: filepath.Join(cacheDir, ".diskv-temp"), }) - t := httpcache.NewTransport(diskcache.NewWithDiskv(d)) + t := httpcache.NewTransport(&sumDiskCache{disk: d}) t.Transport = rt return &cacheRoundTripper{rt: t} @@ -63,3 +65,56 @@ func (rt *cacheRoundTripper) CancelRequest(req *http.Request) { } func (rt *cacheRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.rt.Transport } + +// A sumDiskCache is a cache backend for github.com/gregjones/httpcache. It is +// similar to httpcache's diskcache package, but uses SHA256 sums to ensure +// cache integrity at read time rather than fsyncing each cache entry to +// increase the likelihood they will be persisted at write time. This avoids +// significant performance degradation on MacOS. +// +// See https://github.com/kubernetes/kubernetes/issues/110753 for more. +type sumDiskCache struct { + disk *diskv.Diskv +} + +// Get the requested key from the cache on disk. If Get encounters an error, or +// the returned value is not a SHA256 sum followed by bytes with a matching +// checksum it will return false to indicate a cache miss. +func (c *sumDiskCache) Get(key string) ([]byte, bool) { + b, err := c.disk.Read(sanitize(key)) + if err != nil || len(b) < sha256.Size { + return []byte{}, false + } + + response := b[sha256.Size:] + want := b[:sha256.Size] // The first 32 bytes of the file should be the SHA256 sum. 
+ got := sha256.Sum256(response) + if !bytes.Equal(want, got[:]) { + return []byte{}, false + } + + return response, true +} + +// Set writes the response to a file on disk. The filename will be the SHA256 +// sum of the key. The file will contain a SHA256 sum of the response bytes, +// followed by said response bytes. +func (c *sumDiskCache) Set(key string, response []byte) { + s := sha256.Sum256(response) + _ = c.disk.Write(sanitize(key), append(s[:], response...)) // Nothing we can do with this error. +} + +func (c *sumDiskCache) Delete(key string) { + _ = c.disk.Erase(sanitize(key)) // Nothing we can do with this error. +} + +// Sanitize an httpcache key such that it can be used as a diskv key, which must +// be a valid filename. The httpcache key will either be the requested URL (if +// the request method was GET) or "<method> <url>" for other methods, per the +// httpcache.cacheKey function. +func sanitize(key string) string { + // These keys are not sensitive. We use sha256 to avoid a (potentially + // malicious) collision causing the wrong cache data to be written or + // accessed. 
+ return fmt.Sprintf("%x", sha256.Sum256([]byte(key))) +} diff --git a/discovery/cached/disk/round_tripper_test.go b/discovery/cached/disk/round_tripper_test.go index 13002c63..5f1626c9 100644 --- a/discovery/cached/disk/round_tripper_test.go +++ b/discovery/cached/disk/round_tripper_test.go @@ -18,6 +18,7 @@ package disk import ( "bytes" + "crypto/sha256" "io/ioutil" "net/http" "net/url" @@ -25,6 +26,7 @@ import ( "path/filepath" "testing" + "github.com/peterbourgon/diskv" "github.com/stretchr/testify/assert" ) @@ -40,6 +42,35 @@ func (rt *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return rt.Response, rt.Err } +func BenchmarkDiskCache(b *testing.B) { + cacheDir, err := ioutil.TempDir("", "cache-rt") + if err != nil { + b.Fatal(err) + } + defer os.RemoveAll(cacheDir) + + d := diskv.New(diskv.Options{ + PathPerm: os.FileMode(0750), + FilePerm: os.FileMode(0660), + BasePath: cacheDir, + TempDir: filepath.Join(cacheDir, ".diskv-temp"), + }) + + k := "localhost:8080/apis/batch/v1.json" + v, err := ioutil.ReadFile("../../testdata/apis/batch/v1.json") + if err != nil { + b.Fatal(err) + } + + c := sumDiskCache{disk: d} + + for n := 0; n < b.N; n++ { + c.Set(k, v) + c.Get(k) + c.Delete(k) + } +} + func TestCacheRoundTripper(t *testing.T) { rt := &testRoundTripper{} cacheDir, err := ioutil.TempDir("", "cache-rt") @@ -145,3 +176,146 @@ func TestCacheRoundTripperPathPerm(t *testing.T) { }) assert.NoError(err) } + +func TestSumDiskCache(t *testing.T) { + assert := assert.New(t) + + // Ensure that we'll return a cache miss if the backing file doesn't exist. 
+ t.Run("NoSuchKey", func(t *testing.T) { + cacheDir, err := ioutil.TempDir("", "cache-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cacheDir) + d := diskv.New(diskv.Options{BasePath: cacheDir, TempDir: filepath.Join(cacheDir, ".diskv-temp")}) + c := &sumDiskCache{disk: d} + + key := "testing" + + got, ok := c.Get(key) + assert.False(ok) + assert.Equal([]byte{}, got) + }) + + // Ensure that we'll return a cache miss if the backing file is empty. + t.Run("EmptyFile", func(t *testing.T) { + cacheDir, err := ioutil.TempDir("", "cache-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cacheDir) + d := diskv.New(diskv.Options{BasePath: cacheDir, TempDir: filepath.Join(cacheDir, ".diskv-temp")}) + c := &sumDiskCache{disk: d} + + key := "testing" + + f, err := os.Create(filepath.Join(cacheDir, sanitize(key))) + if err != nil { + t.Fatal(err) + } + f.Close() + + got, ok := c.Get(key) + assert.False(ok) + assert.Equal([]byte{}, got) + }) + + // Ensure that we'll return a cache miss if the backing file has an invalid + // checksum. + t.Run("InvalidChecksum", func(t *testing.T) { + cacheDir, err := ioutil.TempDir("", "cache-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cacheDir) + d := diskv.New(diskv.Options{BasePath: cacheDir, TempDir: filepath.Join(cacheDir, ".diskv-temp")}) + c := &sumDiskCache{disk: d} + + key := "testing" + value := []byte("testing") + mismatchedValue := []byte("testink") + sum := sha256.Sum256(value) + + // Create a file with the sum of 'value' followed by the bytes of + // 'mismatchedValue'. + f, err := os.Create(filepath.Join(cacheDir, sanitize(key))) + if err != nil { + t.Fatal(err) + } + f.Write(sum[:]) + f.Write(mismatchedValue) + f.Close() + + // The mismatched checksum should result in a cache miss. + got, ok := c.Get(key) + assert.False(ok) + assert.Equal([]byte{}, got) + }) + + // Ensure that our disk cache will happily cache over the top of an existing + // value. 
We depend on this behaviour to recover from corrupted cache + // entries. When Get detects a bad checksum it will return a cache miss. + // This should cause httpcache to fall back to its underlying transport and + // to subsequently cache the new value, overwriting the corrupt one. + t.Run("OverwriteExistingKey", func(t *testing.T) { + cacheDir, err := ioutil.TempDir("", "cache-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cacheDir) + d := diskv.New(diskv.Options{BasePath: cacheDir, TempDir: filepath.Join(cacheDir, ".diskv-temp")}) + c := &sumDiskCache{disk: d} + + key := "testing" + value := []byte("cool value!") + + // Write a value. + c.Set(key, value) + got, ok := c.Get(key) + + // Ensure we can read back what we wrote. + assert.True(ok) + assert.Equal(value, got) + + differentValue := []byte("I'm different!") + + // Write a different value. + c.Set(key, differentValue) + got, ok = c.Get(key) + + // Ensure we can read back the different value. + assert.True(ok) + assert.Equal(differentValue, got) + }) + + // Ensure that deleting a key does in fact delete it. + t.Run("DeleteKey", func(t *testing.T) { + cacheDir, err := ioutil.TempDir("", "cache-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cacheDir) + d := diskv.New(diskv.Options{BasePath: cacheDir, TempDir: filepath.Join(cacheDir, ".diskv-temp")}) + c := &sumDiskCache{disk: d} + + key := "testing" + value := []byte("coolValue") + + c.Set(key, value) + + // Ensure we successfully set the value. + got, ok := c.Get(key) + assert.True(ok) + assert.Equal(value, got) + + c.Delete(key) + + // Ensure the value is gone. + got, ok = c.Get(key) + assert.False(ok) + assert.Equal([]byte{}, got) + + // Ensure that deleting a non-existent value is a no-op. 
+ c.Delete(key) + }) +} diff --git a/go.mod b/go.mod index 45bc47f8..e5cd0cb3 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 google.golang.org/protobuf v1.28.0 k8s.io/api v0.0.0-20220727200302-537ea12bb18b - k8s.io/apimachinery v0.0.0-20220727200059-47ba8cbe2b8f + k8s.io/apimachinery v0.0.0-20220729201108-d58901cae3e7 k8s.io/klog/v2 v2.70.1 k8s.io/kube-openapi v0.0.0-20220627174259-011e075b9cb8 k8s.io/utils v0.0.0-20220725171434-9bab9ef40391 @@ -62,5 +62,5 @@ require ( replace ( k8s.io/api => k8s.io/api v0.0.0-20220727200302-537ea12bb18b - k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20220727200059-47ba8cbe2b8f + k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20220729201108-d58901cae3e7 ) diff --git a/go.sum b/go.sum index c757f080..20e52a02 100644 --- a/go.sum +++ b/go.sum @@ -481,8 +481,8 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9 honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.0.0-20220727200302-537ea12bb18b h1:bWhGCDmlF51DNpRAl3k6oyzFRmp/Qc8B16VsJ6Pguok= k8s.io/api v0.0.0-20220727200302-537ea12bb18b/go.mod h1:xKJDb77GsCnm/CFi8Dngm7nnP+Ob6/i8bvXy2hi1s1w= -k8s.io/apimachinery v0.0.0-20220727200059-47ba8cbe2b8f h1:m3zoqZrZtUTBO1MH7OcsQK9hwwekSeU7Us4clrmHYaI= -k8s.io/apimachinery v0.0.0-20220727200059-47ba8cbe2b8f/go.mod h1:SruqPXeym/+E0MDJj3s3ymS3KjqcosO7UaWC8HCOz2w= +k8s.io/apimachinery v0.0.0-20220729201108-d58901cae3e7 h1:wmWVhvmGAk3Bd41nDh/iuu9ASP1UnNjX0pXfUI6zyV4= +k8s.io/apimachinery v0.0.0-20220729201108-d58901cae3e7/go.mod h1:SruqPXeym/+E0MDJj3s3ymS3KjqcosO7UaWC8HCOz2w= k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE= k8s.io/klog/v2 v2.70.1 h1:7aaoSdahviPmR+XkS7FyxlkkXs6tHISSG03RxleQAVQ= k8s.io/klog/v2 v2.70.1/go.mod h1:y1WjHnz7Dj687irZUWR/WLkLc5N1YHtjLdmgWjndZn0=