diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index 2259640d..94dbd3be 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -49,6 +49,8 @@ type Client interface { Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error CloseStmt(closable Closable) error NewConnection(isTemp bool) (string, error) + Encryptor() Encryptor + Decryptor() Decryptor } // WithTransaction runs f within a transaction. @@ -364,6 +366,14 @@ func (c *client) Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj a return err } +func (c *client) Encryptor() Encryptor { + return c.encryptor +} + +func (c *client) Decryptor() Decryptor { + return c.decryptor +} + // toBytes encodes an object to a byte slice func toBytes(obj any) []byte { var buf bytes.Buffer diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go index 702dda4d..1b900090 100644 --- a/pkg/sqlcache/informer/db_mocks_test.go +++ b/pkg/sqlcache/informer/db_mocks_test.go @@ -141,6 +141,34 @@ func (mr *MockClientMockRecorder) CloseStmt(closable any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockClient)(nil).CloseStmt), closable) } +// Decryptor mocks base method. +func (m *MockClient) Decryptor() db.Decryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decryptor") + ret0, _ := ret[0].(db.Decryptor) + return ret0 +} + +// Decryptor indicates an expected call of Decryptor. +func (mr *MockClientMockRecorder) Decryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decryptor", reflect.TypeOf((*MockClient)(nil).Decryptor)) +} + +// Encryptor mocks base method. +func (m *MockClient) Encryptor() db.Encryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encryptor") + ret0, _ := ret[0].(db.Encryptor) + return ret0 +} + +// Encryptor indicates an expected call of Encryptor. +func (mr *MockClientMockRecorder) Encryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encryptor", reflect.TypeOf((*MockClient)(nil).Encryptor)) +} + // NewConnection mocks base method. func (m *MockClient) NewConnection(isTemp bool) (string, error) { m.ctrl.T.Helper() diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go index 0fae0191..7a6864f6 100644 --- a/pkg/sqlcache/informer/factory/db_mocks_test.go +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -57,6 +57,34 @@ func (mr *MockClientMockRecorder) CloseStmt(closable any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockClient)(nil).CloseStmt), closable) } +// Decryptor mocks base method. +func (m *MockClient) Decryptor() db.Decryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decryptor") + ret0, _ := ret[0].(db.Decryptor) + return ret0 +} + +// Decryptor indicates an expected call of Decryptor. +func (mr *MockClientMockRecorder) Decryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decryptor", reflect.TypeOf((*MockClient)(nil).Decryptor)) +} + +// Encryptor mocks base method. +func (m *MockClient) Encryptor() db.Encryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encryptor") + ret0, _ := ret[0].(db.Encryptor) + return ret0 +} + +// Encryptor indicates an expected call of Encryptor. +func (mr *MockClientMockRecorder) Encryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encryptor", reflect.TypeOf((*MockClient)(nil).Encryptor)) +} + // NewConnection mocks base method. func (m *MockClient) NewConnection(isTemp bool) (string, error) { m.ctrl.T.Helper() diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index a2abeb32..7a268d1a 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -87,9 +87,11 @@ const ( rv TEXT NOT NULL, type TEXT NOT NULL, event BLOB NOT NULL, + eventnonce BLOB, + dekid BLOB, PRIMARY KEY (type, rv) )` - listEventsAfterFmt = `SELECT type, rv, event + listEventsAfterFmt = `SELECT type, rv, event, eventnonce, dekid FROM "%s_events" WHERE rowid > ? ` @@ -243,7 +245,7 @@ func NewListOptionIndexer(ctx context.Context, s Store, opts ListOptionIndexerOp } l.upsertEventsQuery = fmt.Sprintf( - `REPLACE INTO "%s_events"(rv, type, event) VALUES (?, ?, ?)`, + `REPLACE INTO "%s_events"(rv, type, event, eventnonce, dekid) VALUES (?, ?, ?, ?, ?)`, dbName, ) l.upsertEventsStmt = l.Prepare(l.upsertEventsQuery) @@ -321,9 +323,7 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events defer rows.Close() for rows.Next() { - var typ, rv string - var buf sql.RawBytes - err := rows.Scan(&typ, &rv, &buf) + typ, buf, err := l.decryptScanEvent(rows) if err != nil { return fmt.Errorf("scanning event row: %w", err) } @@ -370,6 +370,24 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events return nil } +func (l *ListOptionIndexer) decryptScanEvent(rows db.Rows) (watch.EventType, []byte, error) { + var typ, rv string + var event, eventNonce sql.RawBytes + var kid uint32 + err := rows.Scan(&typ, &rv, &event, &eventNonce, &kid) + if err != nil { + return watch.Error, nil, err + } + if l.Decryptor() != nil && l.GetShouldEncrypt() { + decryptedData, err := l.Decryptor().Decrypt(event, eventNonce, kid) + if err != nil { + return watch.Error, nil, err + } + return watch.EventType(typ), decryptedData, nil + } + return watch.EventType(typ), event, nil +} + func toBytes(obj any) []byte { var buf bytes.Buffer enc := gob.NewEncoder(&buf) @@ -452,9 +470,10 @@ func (l *ListOptionIndexer) notifyEvent(eventType watch.EventType, oldObj any, o } latestRV := acc.GetResourceVersion() - _, err = tx.Stmt(l.upsertEventsStmt).Exec(latestRV, eventType, toBytes(obj)) + + err = l.upsertEvent(tx, eventType, latestRV, obj) if err != nil { - return &db.QueryError{QueryString: l.upsertEventsQuery, Err: err} + return err } l.watchersLock.RLock() @@ -476,6 +495,26 @@ func (l *ListOptionIndexer) notifyEvent(eventType watch.EventType, oldObj any, o return nil } +func (l *ListOptionIndexer) upsertEvent(tx transaction.Client, eventType watch.EventType, latestRV string, obj any) error { + objBytes := toBytes(obj) + var dataNonce []byte + var err error + var kid uint32 + if l.Encryptor() != nil && l.GetShouldEncrypt() { + objBytes, dataNonce, kid, err = l.Encryptor().Encrypt(objBytes) + if err != nil { + return err + } + } + + _, err = tx.Stmt(l.upsertEventsStmt).Exec(latestRV, eventType, objBytes, dataNonce, kid) + if err != nil { + return &db.QueryError{QueryString: l.upsertEventsQuery, Err: err} + } + + return err +} + // addIndexFields saves sortable/filterable fields into tables func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.Client) error { args := []any{key} diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index bb5156a9..f7f7725e 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -31,7 +31,7 @@ import ( "k8s.io/client-go/tools/cache" ) -func makeListOptionIndexer(ctx context.Context, opts ListOptionIndexerOptions) (*ListOptionIndexer, string, error) { +func makeListOptionIndexer(ctx context.Context, opts ListOptionIndexerOptions, shouldEncrypt bool) (*ListOptionIndexer, string, error) { gvk := schema.GroupVersionKind{ Group: "", Version: "v1", @@ -50,7 +50,7 @@ func makeListOptionIndexer(ctx context.Context, opts ListOptionIndexerOptions) ( return nil, "", err } - s, err := store.NewStore(ctx, example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, false, gvk, name, nil, nil) + s, err := store.NewStore(ctx, example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, shouldEncrypt, gvk, name, nil, nil) if err != nil { return nil, "", err } @@ -1008,7 +1008,7 @@ func TestNewListOptionIndexerEasy(t *testing.T) { Fields: fields, IsNamespaced: true, } - loi, dbPath, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) @@ -1218,7 +1218,7 @@ func TestUserDefinedExtractFunction(t *testing.T) { Fields: fields, IsNamespaced: true, } - loi, dbPath, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) @@ -2234,6 +2234,93 @@ func TestGetField(t *testing.T) { } } +func TestWatchEncryption(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + opts := ListOptionIndexerOptions{ + Fields: [][]string{ + {"metadata", "somefield"}, + {"spec", "replicas"}, + {"spec", "minReplicas"}, + }, + IsNamespaced: true, + } + // shouldEncrypt = true to ensure we can write + read from encrypted events + loi, dbPath, err := makeListOptionIndexer(ctx, opts, true) + defer cleanTempFiles(dbPath) + assert.NoError(t, err) + + foo := &unstructured.Unstructured{ + Object: map[string]any{ + "metadata": map[string]any{ + "name": "foo", + }, + "spec": map[string]any{ + "replicas": int64(1), + }, + }, + } + foo.SetResourceVersion("100") + foo2 := foo.DeepCopy() + foo2.SetResourceVersion("120") + + startWatcher := func(ctx context.Context) (chan watch.Event, chan error) { + errCh := make(chan error, 1) + eventsCh := make(chan watch.Event, 100) + go func() { + watchErr := loi.Watch(ctx, WatchOptions{ + // Make a watch request to this specific resource version to be sure we go get from SQL database + ResourceVersion: "100", + }, eventsCh) + errCh <- watchErr + }() + time.Sleep(100 * time.Millisecond) + return eventsCh, errCh + } + + waitStopWatcher := func(errCh chan error) error { + select { + case <-time.After(time.Second * 5): + return fmt.Errorf("not finished in time") + case err := <-errCh: + return err + } + } + + receiveEvents := func(eventsCh chan watch.Event) []watch.Event { + timer := time.NewTimer(time.Millisecond * 50) + var events []watch.Event + for { + select { + case <-timer.C: + return events + case ev := <-eventsCh: + events = append(events, ev) + } + } + } + + err = loi.Add(foo) + assert.NoError(t, err) + err = loi.Update(foo2) + assert.NoError(t, err) + + watcher1, errCh1 := startWatcher(ctx) + events := receiveEvents(watcher1) + assert.Len(t, events, 1) + assert.Equal(t, []watch.Event{ + { + Type: watch.Modified, + Object: foo2, + }, + }, events) + + cancel() + + err = waitStopWatcher(errCh1) + assert.NoError(t, err) +} + func TestWatchMany(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -2245,7 +2332,7 @@ func TestWatchMany(t *testing.T) { }, IsNamespaced: true, } - loi, dbPath, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) @@ -2502,7 +2589,7 @@ func TestWatchFilter(t *testing.T) { Fields: [][]string{{"metadata", "somefield"}}, IsNamespaced: true, } - loi, dbPath, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) @@ -2594,7 +2681,7 @@ func TestWatchResourceVersion(t *testing.T) { opts := ListOptionIndexerOptions{ IsNamespaced: true, } - loi, dbPath, err := makeListOptionIndexer(parentCtx, opts) + loi, dbPath, err := makeListOptionIndexer(parentCtx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) @@ -2748,7 +2835,7 @@ func TestWatchGarbageCollection(t *testing.T) { GCInterval: 40 * time.Millisecond, GCKeepCount: 2, } - loi, dbPath, err := makeListOptionIndexer(parentCtx, opts) + loi, dbPath, err := makeListOptionIndexer(parentCtx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) @@ -2859,7 +2946,7 @@ func TestNonNumberResourceVersion(t *testing.T) { Fields: [][]string{{"metadata", "somefield"}}, IsNamespaced: true, } - loi, dbPath, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts, false) defer cleanTempFiles(dbPath) assert.NoError(t, err) diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go index 70016975..4f92758a 100644 --- a/pkg/sqlcache/informer/sql_mocks_test.go +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -71,6 +71,20 @@ func (mr *MockStoreMockRecorder) CloseStmt(closable any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockStore)(nil).CloseStmt), closable) } +// Decryptor mocks base method. +func (m *MockStore) Decryptor() db.Decryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decryptor") + ret0, _ := ret[0].(db.Decryptor) + return ret0 +} + +// Decryptor indicates an expected call of Decryptor. +func (mr *MockStoreMockRecorder) Decryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decryptor", reflect.TypeOf((*MockStore)(nil).Decryptor)) +} + // Delete mocks base method. func (m *MockStore) Delete(obj any) error { m.ctrl.T.Helper() @@ -85,6 +99,20 @@ func (mr *MockStoreMockRecorder) Delete(obj any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockStore)(nil).Delete), obj) } +// Encryptor mocks base method. +func (m *MockStore) Encryptor() db.Encryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encryptor") + ret0, _ := ret[0].(db.Encryptor) + return ret0 +} + +// Encryptor indicates an expected call of Encryptor. +func (mr *MockStoreMockRecorder) Encryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encryptor", reflect.TypeOf((*MockStore)(nil).Encryptor)) +} + // Get mocks base method. func (m *MockStore) Get(obj any) (any, bool, error) { m.ctrl.T.Helper() diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go index c6ed7b07..67c1c975 100644 --- a/pkg/sqlcache/store/db_mocks_test.go +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -141,6 +141,34 @@ func (mr *MockClientMockRecorder) CloseStmt(closable any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockClient)(nil).CloseStmt), closable) } +// Decryptor mocks base method. +func (m *MockClient) Decryptor() db.Decryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Decryptor") + ret0, _ := ret[0].(db.Decryptor) + return ret0 +} + +// Decryptor indicates an expected call of Decryptor. +func (mr *MockClientMockRecorder) Decryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decryptor", reflect.TypeOf((*MockClient)(nil).Decryptor)) +} + +// Encryptor mocks base method. +func (m *MockClient) Encryptor() db.Encryptor { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Encryptor") + ret0, _ := ret[0].(db.Encryptor) + return ret0 +} + +// Encryptor indicates an expected call of Encryptor. +func (mr *MockClientMockRecorder) Encryptor() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encryptor", reflect.TypeOf((*MockClient)(nil).Encryptor)) +} + // NewConnection mocks base method. func (m *MockClient) NewConnection(isTemp bool) (string, error) { m.ctrl.T.Helper()