From 794e665d7b111641a96670a23f4cb8c22fac56be Mon Sep 17 00:00:00 2001 From: WanLinghao Date: Fri, 31 Aug 2018 16:04:02 +0800 Subject: [PATCH] Currently, token manager use keyFunc like: `fmt.Sprintf("%q/%q/%#v", name, namespace, tr.Spec)`. Since tr.Spec contains point fields, new token request would not reuse the cache at all. This patch fix this, also adds unit test. Signed-off-by: Mike Danese --- pkg/kubelet/token/token_manager.go | 13 +- pkg/kubelet/token/token_manager_test.go | 194 +++++++++++++++++++++++- 2 files changed, 201 insertions(+), 6 deletions(-) diff --git a/pkg/kubelet/token/token_manager.go b/pkg/kubelet/token/token_manager.go index d081bdd149d..ec755faa393 100644 --- a/pkg/kubelet/token/token_manager.go +++ b/pkg/kubelet/token/token_manager.go @@ -74,6 +74,7 @@ type Manager struct { // * If refresh fails and the old token is no longer valid, return an error func (m *Manager) GetServiceAccountToken(namespace, name string, tr *authenticationv1.TokenRequest) (*authenticationv1.TokenRequest, error) { key := keyFunc(name, namespace, tr) + ctr, ok := m.get(key) if ok && !m.requiresRefresh(ctr) { @@ -147,5 +148,15 @@ func (m *Manager) requiresRefresh(tr *authenticationv1.TokenRequest) bool { // keys should be nonconfidential and safe to log func keyFunc(name, namespace string, tr *authenticationv1.TokenRequest) string { - return fmt.Sprintf("%q/%q/%#v", name, namespace, tr.Spec) + var exp int64 + if tr.Spec.ExpirationSeconds != nil { + exp = *tr.Spec.ExpirationSeconds + } + + var ref authenticationv1.BoundObjectReference + if tr.Spec.BoundObjectRef != nil { + ref = *tr.Spec.BoundObjectRef + } + + return fmt.Sprintf("%q/%q/%#v/%#v/%#v", name, namespace, tr.Spec.Audiences, exp, ref) } diff --git a/pkg/kubelet/token/token_manager_test.go b/pkg/kubelet/token/token_manager_test.go index 2d877652dd6..2cc2766808f 100644 --- a/pkg/kubelet/token/token_manager_test.go +++ b/pkg/kubelet/token/token_manager_test.go @@ -43,7 +43,7 @@ func TestTokenCachingAndExpiration(t *testing.T) { exp: time.Hour, f: func(t *testing.T, s *suite) { s.clock.SetTime(s.clock.Now().Add(50 * time.Minute)) - if _, err := s.mgr.GetServiceAccountToken("a", "b", &authenticationv1.TokenRequest{}); err != nil { + if _, err := s.mgr.GetServiceAccountToken("a", "b", getTokenRequest()); err != nil { t.Fatalf("unexpected error: %v", err) } if s.tg.count != 2 { @@ -56,7 +56,7 @@ func TestTokenCachingAndExpiration(t *testing.T) { exp: 40 * time.Hour, f: func(t *testing.T, s *suite) { s.clock.SetTime(s.clock.Now().Add(25 * time.Hour)) - if _, err := s.mgr.GetServiceAccountToken("a", "b", &authenticationv1.TokenRequest{}); err != nil { + if _, err := s.mgr.GetServiceAccountToken("a", "b", getTokenRequest()); err != nil { t.Fatalf("unexpected error: %v", err) } if s.tg.count != 2 { @@ -73,7 +73,7 @@ func TestTokenCachingAndExpiration(t *testing.T) { err: fmt.Errorf("err"), } s.mgr.getToken = tg.getToken - tr, err := s.mgr.GetServiceAccountToken("a", "b", &authenticationv1.TokenRequest{}) + tr, err := s.mgr.GetServiceAccountToken("a", "b", getTokenRequest()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -105,14 +105,14 @@ func TestTokenCachingAndExpiration(t *testing.T) { } s.mgr.getToken = s.tg.getToken s.mgr.clock = s.clock - if _, err := s.mgr.GetServiceAccountToken("a", "b", &authenticationv1.TokenRequest{}); err != nil { + if _, err := s.mgr.GetServiceAccountToken("a", "b", getTokenRequest()); err != nil { t.Fatalf("unexpected error: %v", err) } if s.tg.count != 1 { t.Fatalf("unexpected client call, got: %d, want: 1", s.tg.count) } - if _, err := s.mgr.GetServiceAccountToken("a", "b", &authenticationv1.TokenRequest{}); err != nil { + if _, err := s.mgr.GetServiceAccountToken("a", "b", getTokenRequest()); err != nil { t.Fatalf("unexpected error: %v", err) } if s.tg.count != 1 { @@ -221,3 +221,187 @@ func TestCleanup(t *testing.T) { }) } } + +func TestKeyFunc(t *testing.T) { + type tokenRequestUnit struct { + name string + namespace string + tr *authenticationv1.TokenRequest + } + getKeyFunc := func(u tokenRequestUnit) string { + return keyFunc(u.name, u.namespace, u.tr) + } + + cases := []struct { + name string + trus []tokenRequestUnit + target tokenRequestUnit + + shouldHit bool + }{ + { + name: "hit", + trus: []tokenRequestUnit{ + { + name: "foo-sa", + namespace: "foo-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "foo-pod", + UID: "foo-uid", + }, + }, + }, + }, + { + name: "ame-sa", + namespace: "ame-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"ame1", "ame2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "ame-pod", + UID: "ame-uid", + }, + }, + }, + }, + }, + target: tokenRequestUnit{ + name: "foo-sa", + namespace: "foo-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "foo-pod", + UID: "foo-uid", + }, + }, + }, + }, + shouldHit: true, + }, + { + name: "not hit due to different ExpirationSeconds", + trus: []tokenRequestUnit{ + { + name: "foo-sa", + namespace: "foo-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "foo-pod", + UID: "foo-uid", + }, + }, + }, + }, + }, + target: tokenRequestUnit{ + name: "foo-sa", + namespace: "foo-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + //everthing is same besides ExpirationSeconds + ExpirationSeconds: getInt64Point(2001), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "foo-pod", + UID: "foo-uid", + }, + }, + }, + }, + shouldHit: false, + }, + { + name: "not hit due to different BoundObjectRef", + trus: []tokenRequestUnit{ + { + name: "foo-sa", + namespace: "foo-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "foo-pod", + UID: "foo-uid", + }, + }, + }, + }, + }, + target: tokenRequestUnit{ + name: "foo-sa", + namespace: "foo-ns", + tr: &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + //everthing is same besides BoundObjectRef.Name + Name: "diff-pod", + UID: "foo-uid", + }, + }, + }, + }, + shouldHit: false, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + mgr := NewManager(nil) + mgr.clock = clock.NewFakeClock(time.Time{}.Add(30 * 24 * time.Hour)) + for _, tru := range c.trus { + mgr.set(getKeyFunc(tru), &authenticationv1.TokenRequest{ + Status: authenticationv1.TokenRequestStatus{ + //make sure the token cache would not be cleaned by token manager clenaup func + ExpirationTimestamp: metav1.Time{Time: mgr.clock.Now().Add(50 * time.Minute)}, + }, + }) + } + _, hit := mgr.get(getKeyFunc(c.target)) + + if hit != c.shouldHit { + t.Errorf("%s got unexpected hit result: expected to be %t, got %t", c.name, c.shouldHit, hit) + } + }) + } + +} + +func getTokenRequest() *authenticationv1.TokenRequest { + return &authenticationv1.TokenRequest{ + Spec: authenticationv1.TokenRequestSpec{ + Audiences: []string{"foo1", "foo2"}, + ExpirationSeconds: getInt64Point(2000), + BoundObjectRef: &authenticationv1.BoundObjectReference{ + Kind: "pod", + Name: "foo-pod", + UID: "foo-uid", + }, + }, + } +} + +func getInt64Point(v int64) *int64 { + return &v +}