diff --git a/agent/main.go b/agent/main.go index ef1ea32cc..055289202 100644 --- a/agent/main.go +++ b/agent/main.go @@ -115,6 +115,8 @@ func mergeUnique(slice []string, merge []string) []string { } func loadExtensions() { + appPorts := parseEnvVar(shared.AppPortsEnvVar) + dir, _ := filepath.Abs(filepath.Dir(os.Args[0])) extensionsDir := path.Join(dir, "./extensions/") @@ -140,6 +142,10 @@ func loadExtensions() { extension.Dissector = dissector log.Printf("Extension Properties: %+v\n", extension) extensions[i] = extension + if ports, ok := appPorts[extension.Protocol.Name]; ok { + log.Printf("Overriding \"%s\" extension's ports to: %v", extension.Protocol.Name, ports) + extension.Protocol.Ports = ports + } extensionsMap[extension.Protocol.Name] = extension allExtensionPorts = mergeUnique(allExtensionPorts, extension.Protocol.Ports) } @@ -186,13 +192,25 @@ func CORSMiddleware() gin.HandlerFunc { } } +func parseEnvVar(env string) map[string][]string { + var mapOfList map[string][]string + + val, present := os.LookupEnv(env) + + if !present { + return mapOfList + } + + err := json.Unmarshal([]byte(val), &mapOfList) + if err != nil { + panic(fmt.Sprintf("env var %s's value of %s is invalid! must be map[string][]string %v", env, mapOfList, err)) + } + return mapOfList +} + func getTapTargets() []string { nodeName := os.Getenv(shared.NodeNameEnvVar) - var tappedAddressesPerNodeDict map[string][]string - err := json.Unmarshal([]byte(os.Getenv(shared.TappedAddressesPerNodeDictEnvVar)), &tappedAddressesPerNodeDict) - if err != nil { - panic(fmt.Sprintf("env var %s's value of %s is invalid! must be map[string][]string %v", shared.TappedAddressesPerNodeDictEnvVar, tappedAddressesPerNodeDict, err)) - } + tappedAddressesPerNodeDict := parseEnvVar(shared.TappedAddressesPerNodeDictEnvVar) return tappedAddressesPerNodeDict[nodeName] } diff --git a/shared/consts.go b/shared/consts.go index 71cafdff2..07bcd0130 100644 --- a/shared/consts.go +++ b/shared/consts.go @@ -8,4 +8,5 @@ const ( MaxEntriesDBSizeBytesEnvVar = "MAX_ENTRIES_DB_BYTES" RulePolicyPath = "/app/enforce-policy/" RulePolicyFileName = "enforce-policy.yaml" + AppPortsEnvVar = "APP_PORTS" ) diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go index 5e42cb9eb..170bc3867 100644 --- a/tap/passive_tapper.go +++ b/tap/passive_tapper.go @@ -18,7 +18,6 @@ import ( "os/signal" "runtime" "runtime/pprof" - "strconv" "strings" "sync" "time" @@ -39,19 +38,6 @@ const cleanPeriod = time.Second * 10 var remoteOnlyOutboundPorts = []int{80, 443} -func parseAppPorts(appPortsList string) []int { - ports := make([]int, 0) - for _, portStr := range strings.Split(appPortsList, ",") { - parsedInt, parseError := strconv.Atoi(portStr) - if parseError != nil { - log.Printf("Provided app port %v is not a valid number!", portStr) - } else { - ports = append(ports, parsedInt) - } - } - return ports -} - var maxcount = flag.Int64("c", -1, "Only grab this many packets, then exit") var decoder = flag.String("decoder", "", "Name of the decoder to use (default: guess from capture)") var statsevery = flag.Int("stats", 60, "Output statistics every N seconds") @@ -241,17 +227,7 @@ func startPassiveTapper(outputItems chan *api.OutputChannelItem, allExtensionPor ownIps = localhostIPs } - appPortsStr := os.Getenv(AppPortsEnvVar) - var appPorts []int - if appPortsStr == "" { - rlog.Info("Received empty/no APP_PORTS env var! only listening to ports:", allExtensionPorts) - appPorts = make([]int, 0) - } else { - appPorts = parseAppPorts(appPortsStr) - } - SetFilterPorts(appPorts) - - log.Printf("App Ports: %v", gSettings.filterPorts) + log.Printf("App Ports: %v", allExtensionPorts) var handle *pcap.Handle var err error diff --git a/tap/settings.go b/tap/settings.go index 7c8636239..96f12b2db 100644 --- a/tap/settings.go +++ b/tap/settings.go @@ -14,25 +14,13 @@ const ( ) type globalSettings struct { - filterPorts []int filterAuthorities []string } var gSettings = &globalSettings{ - filterPorts: []int{}, filterAuthorities: []string{}, } -func SetFilterPorts(ports []int) { - gSettings.filterPorts = ports -} - -func GetFilterPorts() []int { - ports := make([]int, len(gSettings.filterPorts)) - copy(ports, gSettings.filterPorts) - return ports -} - func SetFilterAuthorities(ipAddresses []string) { gSettings.filterAuthorities = ipAddresses } diff --git a/tap/tcp_stream_factory.go b/tap/tcp_stream_factory.go index 0b5e5ca26..887869da9 100644 --- a/tap/tcp_stream_factory.go +++ b/tap/tcp_stream_factory.go @@ -29,7 +29,7 @@ func (factory *tcpStreamFactory) New(net, transport gopacket.Flow, tcp *layers.T fsmOptions := reassembly.TCPSimpleFSMOptions{ SupportMissingEstablishment: *allowmissinginit, } - rlog.Debugf("Current App Ports: %v", gSettings.filterPorts) + rlog.Debugf("Current App Ports: %v", factory.AllExtensionPorts) srcIp := net.Src().String() dstIp := net.Dst().String() srcPort := transport.Src().String()