Replace custom Redis config struct with go-redis UniversalOptions (adds sentinel & cluster support) (#4306)

This commit is contained in:
Milos Gajdos
2024-07-04 16:00:37 +01:00
committed by GitHub
7 changed files with 299 additions and 157 deletions

View File

@@ -8,6 +8,8 @@ import (
"reflect"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
// Configuration is a versioned registry configuration, intended to be provided by a yaml file, and
@@ -259,44 +261,6 @@ type FileChecker struct {
Threshold int `yaml:"threshold,omitempty"`
}
// Redis configures the redis pool available to the registry webapp.
type Redis struct {
// Addr specifies the redis instance available to the application.
Addr string `yaml:"addr,omitempty"`
// Usernames can be used as a finer-grained permission control since the introduction of the redis 6.0.
Username string `yaml:"username,omitempty"`
// Password string to use when making a connection.
Password string `yaml:"password,omitempty"`
// DB specifies the database to connect to on the redis instance.
DB int `yaml:"db,omitempty"`
// TLS configures settings for redis in-transit encryption
TLS struct {
Enabled bool `yaml:"enabled,omitempty"`
} `yaml:"tls,omitempty"`
DialTimeout time.Duration `yaml:"dialtimeout,omitempty"` // timeout for connect
ReadTimeout time.Duration `yaml:"readtimeout,omitempty"` // timeout for reads of data
WriteTimeout time.Duration `yaml:"writetimeout,omitempty"` // timeout for writes of data
// Pool configures the behavior of the redis connection pool.
Pool struct {
// MaxIdle sets the maximum number of idle connections.
MaxIdle int `yaml:"maxidle,omitempty"`
// MaxActive sets the maximum number of connections that should be
// opened before blocking a connection request.
MaxActive int `yaml:"maxactive,omitempty"`
// IdleTimeout sets the amount time to wait before closing
// inactive connections.
IdleTimeout time.Duration `yaml:"idletimeout,omitempty"`
} `yaml:"pool,omitempty"`
}
// HTTPChecker is a type of entry in the health section for checking HTTP URIs.
type HTTPChecker struct {
// Timeout is the duration to wait before timing out the HTTP request
@@ -750,3 +714,172 @@ func Parse(rd io.Reader) (*Configuration, error) {
return config, nil
}
type RedisOptions = redis.UniversalOptions
type RedisTLSOptions struct {
Certificate string `yaml:"certificate,omitempty"`
Key string `yaml:"key,omitempty"`
ClientCAs []string `yaml:"clientcas,omitempty"`
}
type Redis struct {
Options RedisOptions `yaml:",inline"`
TLS RedisTLSOptions `yaml:"tls,omitempty"`
}
func (c Redis) MarshalYAML() (interface{}, error) {
fields := make(map[string]interface{})
val := reflect.ValueOf(c.Options)
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
field := typ.Field(i)
fieldValue := val.Field(i)
// ignore funcs fields in redis.UniversalOptions
if fieldValue.Kind() == reflect.Func {
continue
}
fields[strings.ToLower(field.Name)] = fieldValue.Interface()
}
// Add TLS fields if they're not empty
if c.TLS.Certificate != "" || c.TLS.Key != "" || len(c.TLS.ClientCAs) > 0 {
fields["tls"] = c.TLS
}
return fields, nil
}
func (c *Redis) UnmarshalYAML(unmarshal func(interface{}) error) error {
var fields map[string]interface{}
err := unmarshal(&fields)
if err != nil {
return err
}
val := reflect.ValueOf(&c.Options).Elem()
typ := val.Type()
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
fieldName := strings.ToLower(field.Name)
if value, ok := fields[fieldName]; ok {
fieldValue := val.Field(i)
if fieldValue.CanSet() {
switch field.Type {
case reflect.TypeOf(time.Duration(0)):
durationStr, ok := value.(string)
if !ok {
return fmt.Errorf("invalid duration value for field: %s", fieldName)
}
duration, err := time.ParseDuration(durationStr)
if err != nil {
return fmt.Errorf("failed to parse duration for field: %s, error: %v", fieldName, err)
}
fieldValue.Set(reflect.ValueOf(duration))
default:
if err := setFieldValue(fieldValue, value); err != nil {
return fmt.Errorf("failed to set value for field: %s, error: %v", fieldName, err)
}
}
}
}
}
// Handle TLS fields
if tlsData, ok := fields["tls"]; ok {
tlsMap, ok := tlsData.(map[interface{}]interface{})
if !ok {
return fmt.Errorf("invalid TLS data structure")
}
if cert, ok := tlsMap["certificate"]; ok {
var isString bool
c.TLS.Certificate, isString = cert.(string)
if !isString {
return fmt.Errorf("Redis TLS certificate must be a string")
}
}
if key, ok := tlsMap["key"]; ok {
var isString bool
c.TLS.Key, isString = key.(string)
if !isString {
return fmt.Errorf("Redis TLS (private) key must be a string")
}
}
if cas, ok := tlsMap["clientcas"]; ok {
caList, ok := cas.([]interface{})
if !ok {
return fmt.Errorf("invalid clientcas data structure")
}
for _, ca := range caList {
if caStr, ok := ca.(string); ok {
c.TLS.ClientCAs = append(c.TLS.ClientCAs, caStr)
}
}
}
}
return nil
}
func setFieldValue(field reflect.Value, value interface{}) error {
if value == nil {
return nil
}
switch field.Kind() {
case reflect.String:
stringValue, ok := value.(string)
if !ok {
return fmt.Errorf("failed to convert value to string")
}
field.SetString(stringValue)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
intValue, ok := value.(int)
if !ok {
return fmt.Errorf("failed to convert value to integer")
}
field.SetInt(int64(intValue))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
uintValue, ok := value.(uint)
if !ok {
return fmt.Errorf("failed to convert value to unsigned integer")
}
field.SetUint(uint64(uintValue))
case reflect.Float32, reflect.Float64:
floatValue, ok := value.(float64)
if !ok {
return fmt.Errorf("failed to convert value to float")
}
field.SetFloat(floatValue)
case reflect.Bool:
boolValue, ok := value.(bool)
if !ok {
return fmt.Errorf("failed to convert value to boolean")
}
field.SetBool(boolValue)
case reflect.Slice:
slice := reflect.MakeSlice(field.Type(), 0, 0)
valueSlice, ok := value.([]interface{})
if !ok {
return fmt.Errorf("failed to convert value to slice")
}
for _, item := range valueSlice {
sliceValue := reflect.New(field.Type().Elem()).Elem()
if err := setFieldValue(sliceValue, item); err != nil {
return err
}
slice = reflect.Append(slice, sliceValue)
}
field.Set(slice)
default:
return fmt.Errorf("unsupported field type: %v", field.Type())
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/suite"
"gopkg.in/yaml.v2"
)
@@ -134,22 +135,23 @@ var configStruct = Configuration{
},
},
Redis: Redis{
Addr: "localhost:6379",
Username: "alice",
Password: "123456",
DB: 1,
Pool: struct {
MaxIdle int `yaml:"maxidle,omitempty"`
MaxActive int `yaml:"maxactive,omitempty"`
IdleTimeout time.Duration `yaml:"idletimeout,omitempty"`
}{
MaxIdle: 16,
MaxActive: 64,
IdleTimeout: time.Second * 300,
Options: redis.UniversalOptions{
Addrs: []string{"localhost:6379"},
Username: "alice",
Password: "123456",
DB: 1,
MaxIdleConns: 16,
PoolSize: 64,
ConnMaxIdleTime: time.Second * 300,
DialTimeout: time.Millisecond * 10,
ReadTimeout: time.Millisecond * 10,
WriteTimeout: time.Millisecond * 10,
},
TLS: RedisTLSOptions{
Certificate: "/foo/cert.crt",
Key: "/foo/key.pem",
ClientCAs: []string{"/path/to/ca.pem"},
},
DialTimeout: time.Millisecond * 10,
ReadTimeout: time.Millisecond * 10,
WriteTimeout: time.Millisecond * 10,
},
Validation: Validation{
Manifests: ValidationManifests{
@@ -197,19 +199,24 @@ notifications:
actions:
- pull
http:
clientcas:
- /path/to/ca.pem
tls:
clientcas:
- /path/to/ca.pem
headers:
X-Content-Type-Options: [nosniff]
redis:
addr: localhost:6379
tls:
certificate: /foo/cert.crt
key: /foo/key.pem
clientcas:
- /path/to/ca.pem
addrs: [localhost:6379]
username: alice
password: 123456
password: "123456"
db: 1
pool:
maxidle: 16
maxactive: 64
idletimeout: 300s
maxidleconns: 16
poolsize: 64
connmaxidletime: 300s
dialtimeout: 10ms
readtimeout: 10ms
writetimeout: 10ms
@@ -289,6 +296,7 @@ func (suite *ConfigSuite) TestParseSimple() {
func (suite *ConfigSuite) TestParseInmemory() {
suite.expectedConfig.Storage = Storage{"inmemory": Parameters{}}
suite.expectedConfig.Log.Fields = nil
suite.expectedConfig.HTTP.TLS.ClientCAs = nil
suite.expectedConfig.Redis = Redis{}
config, err := Parse(bytes.NewReader([]byte(inmemoryConfigYamlV0_1)))
@@ -309,6 +317,7 @@ func (suite *ConfigSuite) TestParseIncomplete() {
suite.expectedConfig.Auth = Auth{"silly": Parameters{"realm": "silly"}}
suite.expectedConfig.Notifications = Notifications{}
suite.expectedConfig.HTTP.Headers = nil
suite.expectedConfig.HTTP.TLS.ClientCAs = nil
suite.expectedConfig.Redis = Redis{}
suite.expectedConfig.Validation.Manifests.Indexes.Platforms = ""
@@ -579,8 +588,14 @@ func copyConfig(config Configuration) *Configuration {
for k, v := range config.HTTP.Headers {
configCopy.HTTP.Headers[k] = v
}
configCopy.HTTP.TLS.ClientCAs = make([]string, 0, len(config.HTTP.TLS.ClientCAs))
configCopy.HTTP.TLS.ClientCAs = append(configCopy.HTTP.TLS.ClientCAs, config.HTTP.TLS.ClientCAs...)
configCopy.Redis = config.Redis
configCopy.Redis.TLS.Certificate = config.Redis.TLS.Certificate
configCopy.Redis.TLS.Key = config.Redis.TLS.Key
configCopy.Redis.TLS.ClientCAs = make([]string, 0, len(config.Redis.TLS.ClientCAs))
configCopy.Redis.TLS.ClientCAs = append(configCopy.Redis.TLS.ClientCAs, config.Redis.TLS.ClientCAs...)
configCopy.Validation = Validation{
Enabled: config.Validation.Enabled,