add explicit typing for continue tests

Our tests are mostly error based and explicit error typing allows
us to test against error types directly. Having made this change also
makes it obvious that our test coverage was lacking in two branches,
specifically, we were previously not testing empty start keys nor were
we testing for invalid start RVs.
This commit is contained in:
Han Kang 2022-05-31 10:23:07 -07:00
parent 135ac17f20
commit 213e380a2e
2 changed files with 69 additions and 18 deletions

View File

@ -19,11 +19,19 @@ package storage
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"path"
"strings"
)
var (
ErrInvalidStartRV = errors.New("continue key is not valid: incorrect encoded start resourceVersion (version meta.k8s.io/v1)")
ErrEmptyStartKey = errors.New("continue key is not valid: encoded start key empty (version meta.k8s.io/v1)")
ErrGenericInvalidKey = errors.New("continue key is not valid")
ErrUnrecognizedEncodedVersion = errors.New("continue key is not valid: server does not recognize this encoded version")
)
// continueToken is a simple structured object for encoding the state of a continue token.
// TODO: if we change the version of the encoded from, we can't start encoding the new version
// until all other servers are upgraded (i.e. we need to support rolling schema)
@ -39,19 +47,19 @@ type continueToken struct {
func DecodeContinue(continueValue, keyPrefix string) (fromKey string, rv int64, err error) {
data, err := base64.RawURLEncoding.DecodeString(continueValue)
if err != nil {
return "", 0, fmt.Errorf("continue key is not valid: %v", err)
return "", 0, fmt.Errorf("%w: %v", ErrGenericInvalidKey, err)
}
var c continueToken
if err := json.Unmarshal(data, &c); err != nil {
return "", 0, fmt.Errorf("continue key is not valid: %v", err)
return "", 0, fmt.Errorf("%w: %v", ErrGenericInvalidKey, err)
}
switch c.APIVersion {
case "meta.k8s.io/v1":
if c.ResourceVersion == 0 {
return "", 0, fmt.Errorf("continue key is not valid: incorrect encoded start resourceVersion (version meta.k8s.io/v1)")
return "", 0, ErrInvalidStartRV
}
if len(c.StartKey) == 0 {
return "", 0, fmt.Errorf("continue key is not valid: encoded start key empty (version meta.k8s.io/v1)")
return "", 0, ErrEmptyStartKey
}
// defend against path traversal attacks by clients - path.Clean will ensure that startKey cannot
// be at a higher level of the hierarchy, and so when we append the key prefix we will end up with
@ -63,11 +71,11 @@ func DecodeContinue(continueValue, keyPrefix string) (fromKey string, rv int64,
}
cleaned := path.Clean(key)
if cleaned != key {
return "", 0, fmt.Errorf("continue key is not valid: %s", c.StartKey)
return "", 0, fmt.Errorf("%w: %v", ErrGenericInvalidKey, c.StartKey)
}
return keyPrefix + cleaned[1:], c.ResourceVersion, nil
default:
return "", 0, fmt.Errorf("continue key is not valid: server does not recognize this encoded version %q", c.APIVersion)
return "", 0, fmt.Errorf("%w %v", ErrUnrecognizedEncodedVersion, c.APIVersion)
}
}

View File

@ -19,6 +19,7 @@ package storage
import (
"encoding/base64"
"encoding/json"
"errors"
"testing"
)
@ -40,23 +41,65 @@ func Test_decodeContinue(t *testing.T) {
args args
wantFromKey string
wantRv int64
wantErr bool
wantErr error
}{
{name: "valid", args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "key"), keyPrefix: "/test/"}, wantRv: 1, wantFromKey: "/test/key"},
{name: "root path", args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "/"), keyPrefix: "/test/"}, wantRv: 1, wantFromKey: "/test/"},
{name: "empty version", args: args{continueValue: encodeContinueOrDie("", 1, "key"), keyPrefix: "/test/"}, wantErr: true},
{name: "invalid version", args: args{continueValue: encodeContinueOrDie("v1", 1, "key"), keyPrefix: "/test/"}, wantErr: true},
{name: "path traversal - parent", args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "../key"), keyPrefix: "/test/"}, wantErr: true},
{name: "path traversal - local", args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "./key"), keyPrefix: "/test/"}, wantErr: true},
{name: "path traversal - double parent", args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "./../key"), keyPrefix: "/test/"}, wantErr: true},
{name: "path traversal - after parent", args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "key/../.."), keyPrefix: "/test/"}, wantErr: true},
{
name: "valid",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "key"), keyPrefix: "/test/"},
wantRv: 1,
wantFromKey: "/test/key",
},
{
name: "root path",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "/"), keyPrefix: "/test/"},
wantRv: 1,
wantFromKey: "/test/",
},
{
name: "empty version",
args: args{continueValue: encodeContinueOrDie("", 1, "key"), keyPrefix: "/test/"},
wantErr: ErrUnrecognizedEncodedVersion,
},
{
name: "invalid version",
args: args{continueValue: encodeContinueOrDie("v1", 1, "key"), keyPrefix: "/test/"},
wantErr: ErrUnrecognizedEncodedVersion,
},
{
name: "invalid RV",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 0, "key"), keyPrefix: "/test/"},
wantErr: ErrInvalidStartRV,
},
{
name: "no start Key",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, ""), keyPrefix: "/test/"},
wantErr: ErrEmptyStartKey,
},
{
name: "path traversal - parent",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "../key"), keyPrefix: "/test/"},
wantErr: ErrGenericInvalidKey,
},
{
name: "path traversal - local",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "./key"), keyPrefix: "/test/"},
wantErr: ErrGenericInvalidKey,
},
{
name: "path traversal - double parent",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "./../key"), keyPrefix: "/test/"},
wantErr: ErrGenericInvalidKey,
},
{
name: "path traversal - after parent",
args: args{continueValue: encodeContinueOrDie("meta.k8s.io/v1", 1, "key/../.."), keyPrefix: "/test/"},
wantErr: ErrGenericInvalidKey,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotFromKey, gotRv, err := DecodeContinue(tt.args.continueValue, tt.args.keyPrefix)
if (err != nil) != tt.wantErr {
if !errors.Is(err, tt.wantErr) {
t.Errorf("decodeContinue() error = %v, wantErr %v", err, tt.wantErr)
return
}