diff --git a/pkg/util/bandwidth/linux.go b/pkg/util/bandwidth/linux.go index 6413930af39..9949aa3d179 100644 --- a/pkg/util/bandwidth/linux.go +++ b/pkg/util/bandwidth/linux.go @@ -101,7 +101,7 @@ func hexCIDR(cidr string) (string, error) { return "", err } ip = ip.Mask(ipnet.Mask) - hexIP := hex.EncodeToString([]byte(ip.To4())) + hexIP := hex.EncodeToString([]byte(ip)) hexMask := ipnet.Mask.String() return hexIP + "/" + hexMask, nil } @@ -119,6 +119,9 @@ func asciiCIDR(cidr string) (string, error) { ip := net.IP(ipData) maskData, err := hex.DecodeString(parts[1]) + if err != nil { + return "", err + } mask := net.IPMask(maskData) size, _ := mask.Size() diff --git a/pkg/util/bandwidth/linux_test.go b/pkg/util/bandwidth/linux_test.go index e005d65427e..980f8f845c7 100644 --- a/pkg/util/bandwidth/linux_test.go +++ b/pkg/util/bandwidth/linux_test.go @@ -94,19 +94,33 @@ func TestNextClassID(t *testing.T) { func TestHexCIDR(t *testing.T) { tests := []struct { + name string input string output string expectErr bool }{ { - input: "1.2.0.0/16", + name: "IPv4 masked", + input: "1.2.3.4/16", output: "01020000/ffff0000", }, { + name: "IPv4 host", input: "172.17.0.2/32", output: "ac110002/ffffffff", }, { + name: "IPv6 masked", + input: "2001:dead:beef::cafe/64", + output: "2001deadbeef00000000000000000000/ffffffffffffffff0000000000000000", + }, + { + name: "IPv6 host", + input: "2001::5/128", + output: "20010000000000000000000000000005/ffffffffffffffffffffffffffffffff", + }, + { + name: "invalid CIDR", input: "foo", expectErr: true, }, @@ -115,21 +129,76 @@ func TestHexCIDR(t *testing.T) { output, err := hexCIDR(test.input) if test.expectErr { if err == nil { - t.Error("unexpected non-error") + t.Errorf("case %s: unexpected non-error", test.name) } } else { if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("case %s: unexpected error: %v", test.name, err) } if output != test.output { - t.Errorf("expected: %s, saw: %s", test.output, output) + t.Errorf("case %s: expected: %s, saw: %s", + test.name, test.output, output) } - input, err := asciiCIDR(output) + } + } +} + +func TestAsciiCIDR(t *testing.T) { + tests := []struct { + name string + input string + output string + expectErr bool + }{ + { + name: "IPv4", + input: "01020000/ffff0000", + output: "1.2.0.0/16", + }, + { + name: "IPv4 host", + input: "ac110002/ffffffff", + output: "172.17.0.2/32", + }, + { + name: "IPv6", + input: "2001deadbeef00000000000000000000/ffffffffffffffff0000000000000000", + output: "2001:dead:beef::/64", + }, + { + name: "IPv6 host", + input: "20010000000000000000000000000005/ffffffffffffffffffffffffffffffff", + output: "2001::5/128", + }, + { + name: "invalid CIDR", + input: "malformed", + expectErr: true, + }, + { + name: "non-hex IP", + input: "nonhex/32", + expectErr: true, + }, + { + name: "non-hex mask", + input: "01020000/badmask", + expectErr: true, + }, + } + for _, test := range tests { + output, err := asciiCIDR(test.input) + if test.expectErr { + if err == nil { + t.Errorf("case %s: unexpected non-error", test.name) + } + } else { if err != nil { - t.Errorf("unexpected error: %v", err) + t.Errorf("case %s: unexpected error: %v", test.name, err) } - if input != test.input { - t.Errorf("expected: %s, saw: %s", test.input, input) + if output != test.output { + t.Errorf("case %s: expected: %s, saw: %s", + test.name, test.output, output) } } }