diff --git a/pkg/sqlcache/informer/informer.go b/pkg/sqlcache/informer/informer.go index 926828f7..f3b50611 100644 --- a/pkg/sqlcache/informer/informer.go +++ b/pkg/sqlcache/informer/informer.go @@ -14,6 +14,7 @@ import ( sqlStore "github.com/rancher/steve/pkg/sqlcache/store" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/watch" @@ -30,11 +31,18 @@ type Informer struct { } type WatchOptions struct { + Filter WatchFilter +} + +type WatchFilter struct { + ID string + Selector labels.Selector + Namespace string } type ByOptionsLister interface { ListByOptions(ctx context.Context, lo *sqltypes.ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) - Watch(ctx context.Context, opts WatchOptions, eventsCh chan<- watch.Event) error + Watch(ctx context.Context, options WatchOptions, eventsCh chan<- watch.Event) error } // this is set to a var so that it can be overridden by test code for mocking purposes diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 14ce8c51..ba4c4eb1 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -15,7 +15,9 @@ import ( "github.com/rancher/steve/pkg/sqlcache/db/transaction" "github.com/rancher/steve/pkg/sqlcache/sqltypes" "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/tools/cache" @@ -32,7 +34,7 @@ type ListOptionIndexer struct { indexedFields []string watchersLock sync.RWMutex - watchers map[*watchKey]chan<- watch.Event + watchers map[*watchKey]*watcher addFieldsQuery string deleteFieldsByKeyQuery string @@ -111,7 +113,7 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names Indexer: i, namespaced: namespaced, indexedFields: indexedFields, - watchers: make(map[*watchKey]chan<- watch.Event), + watchers: make(map[*watchKey]*watcher), } l.RegisterAfterAdd(l.addIndexFields) l.RegisterAfterAdd(l.notifyEventAdded) @@ -202,7 +204,7 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names } func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, eventsCh chan<- watch.Event) error { - key := l.addWatcher(eventsCh) + key := l.addWatcher(eventsCh, opts.Filter) <-ctx.Done() l.removeWatcher(key) return nil @@ -212,10 +214,18 @@ type watchKey struct { _ bool // ensure watchKey is NOT zero-sized to get unique pointers } -func (l *ListOptionIndexer) addWatcher(eventCh chan<- watch.Event) *watchKey { +type watcher struct { + ch chan<- watch.Event + filter WatchFilter +} + +func (l *ListOptionIndexer) addWatcher(eventCh chan<- watch.Event, filter WatchFilter) *watchKey { key := new(watchKey) l.watchersLock.Lock() - l.watchers[key] = eventCh + l.watchers[key] = &watcher{ + ch: eventCh, + filter: filter, + } l.watchersLock.Unlock() return key } @@ -229,21 +239,42 @@ func (l *ListOptionIndexer) removeWatcher(key *watchKey) { /* Core methods */ func (l *ListOptionIndexer) notifyEventAdded(key string, obj any, tx transaction.Client) error { - return l.notifyEvent(watch.Added, obj, tx) + return l.notifyEvent(watch.Added, nil, obj, tx) } func (l *ListOptionIndexer) notifyEventModified(key string, obj any, tx transaction.Client) error { - return l.notifyEvent(watch.Modified, obj, tx) + oldObj, exists, err := l.GetByKey(key) + if err != nil { + return fmt.Errorf("error getting old object: %w", err) + } + + if !exists { + return fmt.Errorf("old object %q should be in store but was not", key) + } + + return l.notifyEvent(watch.Modified, oldObj, obj, tx) } func (l *ListOptionIndexer) notifyEventDeleted(key string, obj any, tx transaction.Client) error { - return l.notifyEvent(watch.Deleted, obj, tx) + oldObj, exists, err := l.GetByKey(key) + if err != nil { + return fmt.Errorf("error getting old object: %w", err) + } + + if !exists { + return fmt.Errorf("old object %q should be in store but was not", key) + } + return l.notifyEvent(watch.Deleted, oldObj, obj, tx) } -func (l *ListOptionIndexer) notifyEvent(eventType watch.EventType, obj any, tx transaction.Client) error { +func (l *ListOptionIndexer) notifyEvent(eventType watch.EventType, oldObj any, obj any, tx transaction.Client) error { l.watchersLock.RLock() for _, watcher := range l.watchers { - watcher <- watch.Event{ + if !matchWatch(watcher.filter.ID, watcher.filter.Namespace, watcher.filter.Selector, oldObj, obj) { + continue + } + + watcher.ch <- watch.Event{ Type: eventType, Object: obj.(runtime.Object), } @@ -1044,3 +1075,33 @@ func toUnstructuredList(items []any) *unstructured.UnstructuredList { } return result } + +func matchWatch(filterName string, filterNamespace string, filterSelector labels.Selector, oldObj any, obj any) bool { + matchOld := false + if oldObj != nil { + matchOld = matchFilter(filterName, filterNamespace, filterSelector, oldObj) + } + return matchOld || matchFilter(filterName, filterNamespace, filterSelector, obj) +} + +func matchFilter(filterName string, filterNamespace string, filterSelector labels.Selector, obj any) bool { + if obj == nil { + return false + } + metadata, err := meta.Accessor(obj) + if err != nil { + return false + } + if filterName != "" && filterName != metadata.GetName() { + return false + } + if filterNamespace != "" && filterNamespace != metadata.GetNamespace() { + return false + } + if filterSelector != nil { + if !filterSelector.Matches(labels.Set(metadata.GetLabels())) { + return false + } + } + return true +} diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index b88bd916..f087f2aa 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/sets" watch "k8s.io/apimachinery/pkg/watch" @@ -1929,3 +1930,161 @@ func TestWatchMany(t *testing.T) { err = waitStopWatcher(errCh3) assert.NoError(t, err) } + +func TestWatchFilter(t *testing.T) { + startWatcher := func(ctx context.Context, loi *ListOptionIndexer, filter WatchFilter) (chan watch.Event, chan error) { + errCh := make(chan error, 1) + eventsCh := make(chan watch.Event, 100) + go func() { + watchErr := loi.Watch(ctx, WatchOptions{Filter: filter}, eventsCh) + errCh <- watchErr + }() + time.Sleep(100 * time.Millisecond) + return eventsCh, errCh + } + + waitStopWatcher := func(errCh chan error) error { + select { + case <-time.After(time.Second * 5): + return fmt.Errorf("not finished in time") + case err := <-errCh: + return err + } + } + + receiveEvents := func(eventsCh chan watch.Event) []watch.Event { + timer := time.NewTimer(time.Millisecond * 50) + var events []watch.Event + for { + select { + case <-timer.C: + return events + case ev := <-eventsCh: + events = append(events, ev) + } + } + } + + foo := &unstructured.Unstructured{} + foo.SetName("foo") + foo.SetNamespace("foo") + foo.SetLabels(map[string]string{ + "app": "foo", + }) + + fooUpdated := foo.DeepCopy() + fooUpdated.SetLabels(map[string]string{ + "app": "changed", + }) + + bar := &unstructured.Unstructured{} + bar.SetName("bar") + bar.SetNamespace("bar") + bar.SetLabels(map[string]string{ + "app": "bar", + }) + + appSelector, err := labels.Parse("app=foo") + assert.NoError(t, err) + + tests := []struct { + name string + filter WatchFilter + setupStore func(store cache.Store) error + expectedEvents []watch.Event + }{ + { + name: "namespace filter", + filter: WatchFilter{Namespace: "foo"}, + setupStore: func(store cache.Store) error { + err := store.Add(foo) + if err != nil { + return err + } + err = store.Add(bar) + if err != nil { + return err + } + return nil + }, + expectedEvents: []watch.Event{{Type: watch.Added, Object: foo}}, + }, + { + name: "selector filter", + filter: WatchFilter{Selector: appSelector}, + setupStore: func(store cache.Store) error { + err := store.Add(foo) + if err != nil { + return err + } + err = store.Add(bar) + if err != nil { + return err + } + err = store.Update(fooUpdated) + if err != nil { + return err + } + return nil + }, + expectedEvents: []watch.Event{ + {Type: watch.Added, Object: foo}, + {Type: watch.Modified, Object: fooUpdated}, + }, + }, + { + name: "id filter", + filter: WatchFilter{ID: "foo"}, + setupStore: func(store cache.Store) error { + err := store.Add(foo) + if err != nil { + return err + } + err = store.Add(bar) + if err != nil { + return err + } + err = store.Update(fooUpdated) + if err != nil { + return err + } + err = store.Update(foo) + if err != nil { + return err + } + return nil + }, + expectedEvents: []watch.Event{ + {Type: watch.Added, Object: foo}, + {Type: watch.Modified, Object: fooUpdated}, + {Type: watch.Modified, Object: foo}, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + loi, err := makeListOptionIndexer(ctx, [][]string{{"metadata", "somefield"}}) + assert.NoError(t, err) + + wCh, errCh := startWatcher(ctx, loi, WatchFilter{ + Namespace: "foo", + }) + + if test.setupStore != nil { + err = test.setupStore(loi) + assert.NoError(t, err) + } + + events := receiveEvents(wCh) + assert.Equal(t, test.expectedEvents, events) + + cancel() + err = waitStopWatcher(errCh) + assert.NoError(t, err) + + }) + } + +} diff --git a/pkg/stores/sqlproxy/proxy_store.go b/pkg/stores/sqlproxy/proxy_store.go index dbe8c15b..c2d4221a 100644 --- a/pkg/stores/sqlproxy/proxy_store.go +++ b/pkg/stores/sqlproxy/proxy_store.go @@ -19,6 +19,7 @@ import ( "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" apitypes "k8s.io/apimachinery/pkg/types" @@ -36,6 +37,7 @@ import ( "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/rancher/steve/pkg/sqlcache/sqltypes" "github.com/rancher/wrangler/v3/pkg/data" + "github.com/rancher/wrangler/v3/pkg/kv" "github.com/rancher/wrangler/v3/pkg/schemas" "github.com/rancher/wrangler/v3/pkg/schemas/validation" "github.com/rancher/wrangler/v3/pkg/summary" @@ -550,10 +552,30 @@ func (s *Store) watch(apiOp *types.APIRequest, schema *types.APISchema, w types. return nil, err } - result := make(chan watch.Event, 1000) + var selector labels.Selector + if w.Selector != "" { + selector, err = labels.Parse(w.Selector) + if err != nil { + return nil, fmt.Errorf("invalid selector: %w", err) + } + } + + result := make(chan watch.Event) go func() { ctx := apiOp.Context() - err := inf.ByOptionsLister.Watch(ctx, informer.WatchOptions{}, result) + idNamespace, _ := kv.RSplit(w.ID, "/") + if idNamespace == "" { + idNamespace = apiOp.Namespace + } + + opts := informer.WatchOptions{ + Filter: informer.WatchFilter{ + ID: w.ID, + Namespace: idNamespace, + Selector: selector, + }, + } + err := inf.ByOptionsLister.Watch(ctx, opts, result) if err != nil { logrus.Error(err) }