diff --git a/test/images/agnhost/connect/connect.go b/test/images/agnhost/connect/connect.go index 3423549e2d0..fcb111dd276 100644 --- a/test/images/agnhost/connect/connect.go +++ b/test/images/agnhost/connect/connect.go @@ -47,12 +47,14 @@ var ( timeout time.Duration protocol string udpData string + sctpData string ) func init() { CmdConnect.Flags().DurationVar(&timeout, "timeout", time.Duration(0), "Maximum time before returning an error") CmdConnect.Flags().StringVar(&protocol, "protocol", "tcp", "The protocol to use to perform the connection, can be tcp, udp or sctp") CmdConnect.Flags().StringVar(&udpData, "udp-data", "hostname", "The UDP payload send to the server") + CmdConnect.Flags().StringVar(&sctpData, "sctp-data", "hostname", "The SCTP payload send to the server") } func main(cmd *cobra.Command, args []string) { @@ -63,7 +65,7 @@ func main(cmd *cobra.Command, args []string) { case "udp": connectUDP(dest, timeout, udpData) case "sctp": - connectSCTP(dest, timeout) + connectSCTP(dest, timeout, sctpData) default: fmt.Fprint(os.Stderr, "Unsupported protocol\n", protocol) os.Exit(1) @@ -103,7 +105,11 @@ func connectTCP(dest string, timeout time.Duration) { os.Exit(1) } -func connectSCTP(dest string, timeout time.Duration) { +func connectSCTP(dest string, timeout time.Duration, data string) { + var ( + buf = make([]byte, 1024) + conn *sctp.SCTPConn + ) addr, err := sctp.ResolveSCTPAddr("sctp", dest) if err != nil { fmt.Fprintf(os.Stderr, "DNS: %v\n", err) @@ -114,11 +120,24 @@ func connectSCTP(dest string, timeout time.Duration) { errCh := make(chan error) go func() { - conn, err := sctp.DialSCTP("sctp", nil, addr) - if err == nil { - conn.Close() + conn, err = sctp.DialSCTP("sctp", nil, addr) + if err != nil { + errCh <- err + return + } + defer func() { + errCh <- conn.Close() + }() + + if _, err = conn.Write([]byte(fmt.Sprintf("%s\n", data))); err != nil { + errCh <- err + return + } + + if _, err = conn.Read(buf); err != nil { + errCh <- err + return } - errCh <- err }() select {