storage/driver: plumb contexts into middlewares

Signed-off-by: Cory Snider <csnider@mirantis.com>
This commit is contained in:
Cory Snider
2023-10-27 17:46:09 -04:00
parent b45b6d18b8
commit b4dc4f3474
8 changed files with 72 additions and 48 deletions

View File

@@ -48,7 +48,7 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
// default value. "aws", only aws IP goes to S3 directly. "awsregion", only
// regions listed in awsregion options goes to S3 directly
// - awsregion: a comma separated string of AWS regions.
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
func newCloudFrontStorageMiddleware(ctx context.Context, storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
// parse baseurl
base, ok := options["baseurl"]
if !ok {
@@ -157,7 +157,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
case "", "none":
awsIPs = nil
case "aws":
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, nil)
awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, nil)
if err != nil {
return nil, err
}
case "awsregion":
var awsRegion []string
if i, ok := options["awsregion"]; ok {
@@ -165,7 +168,10 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
for _, awsRegions := range strings.Split(regions, ",") {
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
}
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
awsIPs, err = newAWSIPs(ctx, ipRangesURL, updateFrequency, awsRegion)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
}

View File

@@ -1,6 +1,7 @@
package middleware
import (
"context"
"os"
"testing"
@@ -15,7 +16,7 @@ var _ = check.Suite(&MiddlewareSuite{})
func (s *MiddlewareSuite) TestNoConfig(c *check.C) {
options := make(map[string]interface{})
_, err := newCloudFrontStorageMiddleware(nil, options)
_, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
c.Assert(err, check.ErrorMatches, "no baseurl provided")
}
@@ -48,7 +49,7 @@ pZeMRablbPQdp8/1NyIwimq1VlG0ohQ4P6qhW7E09ZMC
defer os.Remove(file.Name())
options["privatekey"] = file.Name()
options["keypairid"] = "test"
storageDriver, err := newCloudFrontStorageMiddleware(nil, options)
storageDriver, err := newCloudFrontStorageMiddleware(context.Background(), nil, options)
if err != nil {
t.Fatal(err)
}

View File

@@ -3,6 +3,7 @@ package middleware
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
@@ -23,18 +24,21 @@ const (
// newAWSIPs returns a New awsIP object.
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
func newAWSIPs(ctx context.Context, host string, updateFrequency time.Duration, awsRegion []string) (*awsIPs, error) {
ips := &awsIPs{
host: host,
updateFrequency: updateFrequency,
awsRegion: awsRegion,
updaterStopChan: make(chan bool),
}
if err := ips.tryUpdate(); err != nil {
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
if err := ips.tryUpdate(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
dcontext.GetLogger(ctx).WithError(err).Warn("failed to update AWS IP")
}
go ips.updater()
return ips
return ips, nil
}
// awsIPs tracks a list of AWS ips, filtered by awsRegion
@@ -61,9 +65,13 @@ type prefixEntry struct {
Service string `json:"service"`
}
func fetchAWSIPs(url string) (awsIPResponse, error) {
func fetchAWSIPs(ctx context.Context, url string) (awsIPResponse, error) {
var response awsIPResponse
resp, err := http.Get(url)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return response, err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return response, err
}
@@ -83,8 +91,8 @@ func fetchAWSIPs(url string) (awsIPResponse, error) {
// tryUpdate attempts to download the new set of ip addresses.
// tryUpdate must be thread safe with contains
func (s *awsIPs) tryUpdate() error {
response, err := fetchAWSIPs(s.host)
func (s *awsIPs) tryUpdate(ctx context.Context) error {
response, err := fetchAWSIPs(ctx, s.host)
if err != nil {
return err
}
@@ -135,17 +143,18 @@ func (s *awsIPs) tryUpdate() error {
// This function is meant to be run in a background goroutine.
// It will periodically update the ips from aws.
func (s *awsIPs) updater() {
ctx := context.TODO()
defer close(s.updaterStopChan)
for {
time.Sleep(s.updateFrequency)
select {
case <-s.updaterStopChan:
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
dcontext.GetLogger(ctx).Info("aws ip updater received stop signal")
return
default:
err := s.tryUpdate()
err := s.tryUpdate(ctx)
if err != nil {
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP")
dcontext.GetLogger(ctx).WithError(err).Error("git AWS IP")
}
}
}

View File

@@ -62,7 +62,7 @@ func TestS3TryUpdate(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
assertEqual(t, 1, len(ips.ipv4))
assertEqual(t, 0, len(ips.ipv6))
@@ -77,8 +77,9 @@ func TestMatchIPV6(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
assertEqual(t, 1, len(ips.ipv6))
assertEqual(t, 0, len(ips.ipv4))
@@ -93,8 +94,9 @@ func TestMatchIPV4(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -112,8 +114,9 @@ func TestMatchIPV4_2(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -131,8 +134,9 @@ func TestMatchIPV4WithRegionMatched(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-east-1"})
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -150,8 +154,9 @@ func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
ips.tryUpdate(ctx)
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -169,8 +174,9 @@ func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, []string{"us-west-2"})
ips.tryUpdate(ctx)
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
@@ -187,8 +193,9 @@ func TestInvalidData(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips.tryUpdate()
ctx := context.Background()
ips, _ := newAWSIPs(ctx, serverIPRanges(server), time.Hour, nil)
ips.tryUpdate(ctx)
assertEqual(t, 1, len(ips.ipv4))
}
@@ -205,7 +212,7 @@ func TestInvalidNetworkType(t *testing.T) {
})
defer server.Close()
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
ips, _ := newAWSIPs(context.Background(), serverIPRanges(server), time.Hour, nil)
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
@@ -226,7 +233,7 @@ func TestParsing(t *testing.T) {
t.Parallel()
server := httptest.NewServer(rawMockHandler)
defer server.Close()
schema, err := fetchAWSIPs(server.URL)
schema, err := fetchAWSIPs(context.Background(), server.URL)
assertEqual(t, nil, err)
assertEqual(t, 1, len(schema.Prefixes))
@@ -253,7 +260,7 @@ func TestUpdateCalledRegularly(t *testing.T) {
rw.Write([]byte("ok"))
}))
defer server.Close()
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
newAWSIPs(context.Background(), fmt.Sprintf("%s/", server.URL), time.Second, nil)
time.Sleep(time.Second*4 + time.Millisecond*500)
if updateCount < 4 {
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
@@ -384,7 +391,7 @@ func BenchmarkContainsRandom(b *testing.B) {
}
func BenchmarkContainsProd(b *testing.B) {
ips := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
ips, _ := newAWSIPs(context.Background(), defaultIPRangesURL, defaultUpdateFrequency, nil)
ipv4 := make([][]byte, b.N)
ipv6 := make([][]byte, b.N)
for i := 0; i < b.N; i++ {