diff --git a/pkg/kubelet/cm/cpumanager/containermap/container_map.go b/pkg/kubelet/cm/cpumanager/containermap/container_map.go index 525ec75131e..a173d2a9ba2 100644 --- a/pkg/kubelet/cm/cpumanager/containermap/container_map.go +++ b/pkg/kubelet/cm/cpumanager/containermap/container_map.go @@ -22,51 +22,44 @@ import ( "k8s.io/api/core/v1" ) -// ContainerMap maps (containerID)->(podUID, containerName) -type ContainerMap map[string]map[string]string +// ContainerMap maps (containerID)->(*v1.Pod, *v1.Container) +type ContainerMap map[string]struct { + pod *v1.Pod + container *v1.Container +} // NewContainerMap creates a new ContainerMap struct func NewContainerMap() ContainerMap { return make(ContainerMap) } -// Add adds a mapping of (containerID)->(podUID, containerName) to the ContainerMap +// Add adds a mapping of (containerID)->(*v1.Pod, *v1.Container) to the ContainerMap func (cm ContainerMap) Add(p *v1.Pod, c *v1.Container, containerID string) { - podUID := string(p.UID) - if _, exists := cm[podUID]; !exists { - cm[podUID] = make(map[string]string) - } - cm[podUID][c.Name] = containerID + cm[containerID] = struct { + pod *v1.Pod + container *v1.Container + }{p, c} } -// Remove removes a mapping of (containerID)->(podUID, containerName) from the ContainerMap +// Remove removes a mapping of (containerID)->(*v1.Pod, *.v1.Container) from the ContainerMap func (cm ContainerMap) Remove(containerID string) { - found := false - for podUID := range cm { - for containerName := range cm[podUID] { - if containerID == cm[podUID][containerName] { - delete(cm[podUID], containerName) - found = true - break - } - } - if len(cm[podUID]) == 0 { - delete(cm, podUID) - } - if found { - break - } - } + delete(cm, containerID) } -// Get retrieves a ContainerID from the ContainerMap -func (cm ContainerMap) Get(p *v1.Pod, c *v1.Container) (string, error) { - podUID := string(p.UID) - if _, exists := cm[podUID]; !exists { - return "", fmt.Errorf("pod %s not in ContainerMap", podUID) +// GetContainerID retrieves a ContainerID from the ContainerMap +func (cm ContainerMap) GetContainerID(p *v1.Pod, c *v1.Container) (string, error) { + for key, val := range cm { + if val.pod.UID == p.UID && val.container.Name == c.Name { + return key, nil + } } - if _, exists := cm[podUID][c.Name]; !exists { - return "", fmt.Errorf("container %s not in ContainerMap for pod %s", c.Name, podUID) - } - return cm[podUID][c.Name], nil + return "", fmt.Errorf("container %s not in ContainerMap for pod %s", c.Name, p.UID) +} + +// GetContainerRef retrieves a (*v1.Pod, *v1.Container) pair from the ContainerMap +func (cm ContainerMap) GetContainerRef(containerID string) (*v1.Pod, *v1.Container, error) { + if _, exists := cm[containerID]; !exists { + return nil, nil, fmt.Errorf("containerID %s not in ContainerMap", containerID) + } + return cm[containerID].pod, cm[containerID].container, nil } diff --git a/pkg/kubelet/cm/cpumanager/containermap/container_map_test.go b/pkg/kubelet/cm/cpumanager/containermap/container_map_test.go index de872cd6df0..d48ca8c9cce 100644 --- a/pkg/kubelet/cm/cpumanager/containermap/container_map_test.go +++ b/pkg/kubelet/cm/cpumanager/containermap/container_map_test.go @@ -47,13 +47,25 @@ func TestContainerMap(t *testing.T) { container := v1.Container{Name: tc.containerNames[i]} cm.Add(&pod, &container, tc.containerIDs[i]) - containerID, err := cm.Get(&pod, &container) + + containerID, err := cm.GetContainerID(&pod, &container) if err != nil { - t.Errorf("error adding and retrieving container: %v", err) + t.Errorf("error adding and retrieving containerID: %v", err) } if containerID != tc.containerIDs[i] { t.Errorf("mismatched containerIDs %v, %v", containerID, tc.containerIDs[i]) } + + podRef, containerRef, err := cm.GetContainerRef(containerID) + if err != nil { + t.Errorf("error retrieving container reference: %v", err) + } + if podRef != &pod { + t.Errorf("mismatched pod reference %v, %v", pod.UID, podRef.UID) + } + if containerRef != &container { + t.Errorf("mismatched container reference %v, %v", container.Name, containerRef.Name) + } } // Remove all entries from the containerMap, checking proper removal of @@ -61,7 +73,7 @@ func TestContainerMap(t *testing.T) { for i := range tc.containerNames { container := v1.Container{Name: tc.containerNames[i]} cm.Remove(tc.containerIDs[i]) - containerID, err := cm.Get(&pod, &container) + containerID, err := cm.GetContainerID(&pod, &container) if err == nil { t.Errorf("unexpected retrieval of containerID after removal: %v", containerID) } diff --git a/pkg/kubelet/cm/cpumanager/policy_static.go b/pkg/kubelet/cm/cpumanager/policy_static.go index 85aec3b8ba6..ed022e6827f 100644 --- a/pkg/kubelet/cm/cpumanager/policy_static.go +++ b/pkg/kubelet/cm/cpumanager/policy_static.go @@ -211,7 +211,7 @@ func (p *staticPolicy) AddContainer(s state.State, pod *v1.Pod, container *v1.Co // container is run. for _, initContainer := range pod.Spec.InitContainers { if container.Name != initContainer.Name { - initContainerID, err := p.containerMap.Get(pod, &initContainer) + initContainerID, err := p.containerMap.GetContainerID(pod, &initContainer) if err != nil { continue }