kubectl diff: refactor tracker into a separate type

This means that we can reuse the logic even if we swap out the pruner.
This commit is contained in:
justinsb 2023-03-14 12:06:44 +00:00
parent 6a31757f45
commit 9c5c8b243d
2 changed files with 34 additions and 23 deletions

View File

@ -114,7 +114,9 @@ type DiffOptions struct {
EnforceNamespace bool EnforceNamespace bool
Builder *resource.Builder Builder *resource.Builder
Diff *DiffProgram Diff *DiffProgram
pruner *pruner
pruner *pruner
tracker *tracker
} }
func NewDiffOptions(ioStreams genericclioptions.IOStreams) *DiffOptions { func NewDiffOptions(ioStreams genericclioptions.IOStreams) *DiffOptions {
@ -659,6 +661,7 @@ func (o *DiffOptions) Complete(f cmdutil.Factory, cmd *cobra.Command, args []str
if err != nil { if err != nil {
return err return err
} }
o.tracker = newTracker()
o.pruner = newPruner(o.DynamicClient, mapper, resources, o.Selector) o.pruner = newPruner(o.DynamicClient, mapper, resources, o.Selector)
} }
@ -723,8 +726,8 @@ func (o *DiffOptions) Run() error {
IOStreams: o.Diff.IOStreams, IOStreams: o.Diff.IOStreams,
} }
if o.pruner != nil { if o.tracker != nil {
o.pruner.MarkVisited(info) o.tracker.MarkVisited(info)
} }
err = differ.Diff(obj, printer, o.ShowManagedFields) err = differ.Diff(obj, printer, o.ShowManagedFields)
@ -739,7 +742,7 @@ func (o *DiffOptions) Run() error {
}) })
if o.pruner != nil { if o.pruner != nil {
prunedObjs, err := o.pruner.pruneAll(o.CmdNamespace != "") prunedObjs, err := o.pruner.pruneAll(o.tracker, o.CmdNamespace != "")
if err != nil { if err != nil {
klog.Warningf("pruning failed and could not be evaluated err: %v", err) klog.Warningf("pruning failed and could not be evaluated err: %v", err)
} }

View File

@ -31,37 +31,45 @@ import (
"k8s.io/kubectl/pkg/util/prune" "k8s.io/kubectl/pkg/util/prune"
) )
type tracker struct {
visitedUids sets.Set[types.UID]
visitedNamespaces sets.Set[string]
}
func newTracker() *tracker {
return &tracker{
visitedUids: sets.New[types.UID](),
visitedNamespaces: sets.New[string](),
}
}
type pruner struct { type pruner struct {
mapper meta.RESTMapper mapper meta.RESTMapper
dynamicClient dynamic.Interface dynamicClient dynamic.Interface
visitedUids sets.Set[types.UID] labelSelector string
visitedNamespaces sets.Set[string] resources []prune.Resource
labelSelector string
resources []prune.Resource
} }
func newPruner(dc dynamic.Interface, m meta.RESTMapper, r []prune.Resource, selector string) *pruner { func newPruner(dc dynamic.Interface, m meta.RESTMapper, r []prune.Resource, selector string) *pruner {
return &pruner{ return &pruner{
visitedUids: sets.New[types.UID](), dynamicClient: dc,
visitedNamespaces: sets.New[string](), mapper: m,
dynamicClient: dc, resources: r,
mapper: m, labelSelector: selector,
resources: r,
labelSelector: selector,
} }
} }
func (p *pruner) pruneAll(namespaceSpecified bool) ([]runtime.Object, error) { func (p *pruner) pruneAll(tracker *tracker, namespaceSpecified bool) ([]runtime.Object, error) {
var allPruned []runtime.Object var allPruned []runtime.Object
namespacedRESTMappings, nonNamespacedRESTMappings, err := prune.GetRESTMappings(p.mapper, p.resources, namespaceSpecified) namespacedRESTMappings, nonNamespacedRESTMappings, err := prune.GetRESTMappings(p.mapper, p.resources, namespaceSpecified)
if err != nil { if err != nil {
return allPruned, fmt.Errorf("error retrieving RESTMappings to prune: %v", err) return allPruned, fmt.Errorf("error retrieving RESTMappings to prune: %v", err)
} }
for n := range p.visitedNamespaces { for n := range tracker.visitedNamespaces {
for _, m := range namespacedRESTMappings { for _, m := range namespacedRESTMappings {
if pobjs, err := p.prune(n, m); err != nil { if pobjs, err := p.prune(tracker, n, m); err != nil {
return pobjs, fmt.Errorf("error pruning namespaced object %v: %v", m.GroupVersionKind, err) return pobjs, fmt.Errorf("error pruning namespaced object %v: %v", m.GroupVersionKind, err)
} else { } else {
allPruned = append(allPruned, pobjs...) allPruned = append(allPruned, pobjs...)
@ -69,7 +77,7 @@ func (p *pruner) pruneAll(namespaceSpecified bool) ([]runtime.Object, error) {
} }
} }
for _, m := range nonNamespacedRESTMappings { for _, m := range nonNamespacedRESTMappings {
if pobjs, err := p.prune(metav1.NamespaceNone, m); err != nil { if pobjs, err := p.prune(tracker, metav1.NamespaceNone, m); err != nil {
return allPruned, fmt.Errorf("error pruning nonNamespaced object %v: %v", m.GroupVersionKind, err) return allPruned, fmt.Errorf("error pruning nonNamespaced object %v: %v", m.GroupVersionKind, err)
} else { } else {
allPruned = append(allPruned, pobjs...) allPruned = append(allPruned, pobjs...)
@ -79,7 +87,7 @@ func (p *pruner) pruneAll(namespaceSpecified bool) ([]runtime.Object, error) {
return allPruned, nil return allPruned, nil
} }
func (p *pruner) prune(namespace string, mapping *meta.RESTMapping) ([]runtime.Object, error) { func (p *pruner) prune(tracker *tracker, namespace string, mapping *meta.RESTMapping) ([]runtime.Object, error) {
objList, err := p.dynamicClient.Resource(mapping.Resource). objList, err := p.dynamicClient.Resource(mapping.Resource).
Namespace(namespace). Namespace(namespace).
List(context.TODO(), metav1.ListOptions{ List(context.TODO(), metav1.ListOptions{
@ -105,7 +113,7 @@ func (p *pruner) prune(namespace string, mapping *meta.RESTMapping) ([]runtime.O
continue continue
} }
uid := metadata.GetUID() uid := metadata.GetUID()
if p.visitedUids.Has(uid) { if tracker.visitedUids.Has(uid) {
continue continue
} }
@ -115,14 +123,14 @@ func (p *pruner) prune(namespace string, mapping *meta.RESTMapping) ([]runtime.O
} }
// MarkVisited marks visited namespaces and uids // MarkVisited marks visited namespaces and uids
func (p *pruner) MarkVisited(info *resource.Info) { func (t *tracker) MarkVisited(info *resource.Info) {
if info.Namespaced() { if info.Namespaced() {
p.visitedNamespaces.Insert(info.Namespace) t.visitedNamespaces.Insert(info.Namespace)
} }
metadata, err := meta.Accessor(info.Object) metadata, err := meta.Accessor(info.Object)
if err != nil { if err != nil {
return return
} }
p.visitedUids.Insert(metadata.GetUID()) t.visitedUids.Insert(metadata.GetUID())
} }