diff --git a/cmd/tap.go b/cmd/tap.go index 60ed0c3ff..e74284784 100644 --- a/cmd/tap.go +++ b/cmd/tap.go @@ -56,7 +56,7 @@ func init() { tapCmd.Flags().String(configStructs.StorageLimitLabel, defaultTapConfig.StorageLimit, "Override the default storage limit (per node)") tapCmd.Flags().String(configStructs.StorageClassLabel, defaultTapConfig.StorageClass, "Override the default storage class of the PersistentVolumeClaim (per node)") tapCmd.Flags().Bool(configStructs.DryRunLabel, defaultTapConfig.DryRun, "Preview of all pods matching the regex, without tapping them") - tapCmd.Flags().StringP(configStructs.PcapLabel, "p", defaultTapConfig.Pcap, fmt.Sprintf("Capture from a PCAP snapshot of %s (.tar.gz) using your Docker Daemon instead of Kubernetes. TAR path, S3 URL (object or bucket)", misc.Software)) + tapCmd.Flags().StringP(configStructs.PcapLabel, "p", defaultTapConfig.Pcap, fmt.Sprintf("Capture from a PCAP snapshot of %s (.tar.gz) using your Docker Daemon instead of Kubernetes. TAR path from the file system or an S3 URL (object, folder or the bucket)", misc.Software)) tapCmd.Flags().Bool(configStructs.ServiceMeshLabel, defaultTapConfig.ServiceMesh, "Capture the encrypted traffic if the cluster is configured with a service mesh and with mTLS") tapCmd.Flags().Bool(configStructs.TlsLabel, defaultTapConfig.Tls, "Capture the traffic that's encrypted with OpenSSL or Go crypto/tls libraries") tapCmd.Flags().Bool(configStructs.IgnoreTaintedLabel, defaultTapConfig.IgnoreTainted, "Ignore tainted pods while running Worker DaemonSet") diff --git a/cmd/tapPcapRunner.go b/cmd/tapPcapRunner.go index bcbd8937f..60b2a94a0 100644 --- a/cmd/tapPcapRunner.go +++ b/cmd/tapPcapRunner.go @@ -306,7 +306,23 @@ func downloadTarFromS3(s3Url string) (tarPath string, err error) { return } - if key == "" { + var file *os.File + file, err = os.CreateTemp(os.TempDir(), filepath.Base(key)) + if err != nil { + return + } + defer file.Close() + + log.Info().Str("bucket", bucket).Str("key", key).Msg("Downloading from S3") + + downloader := manager.NewDownloader(client) + _, err = downloader.Download(context.TODO(), file, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + if err != nil { + log.Info().Err(err).Msg("S3 object is not found. Assuming URL is not a single object. Listing the objects in given folder or the bucket to download...") + var tempDirPath string tempDirPath, err = os.MkdirTemp(os.TempDir(), "kubeshark_*") if err != nil { @@ -314,25 +330,30 @@ func downloadTarFromS3(s3Url string) (tarPath string, err error) { } for _, object := range listObjectsOutput.Contents { - key = *object.Key - fullPath := filepath.Join(tempDirPath, key) + objectKey := *object.Key + if !strings.HasPrefix(objectKey, key) { + continue + } + + fullPath := filepath.Join(tempDirPath, objectKey) err = os.MkdirAll(filepath.Dir(fullPath), os.ModePerm) if err != nil { return } - var file *os.File - file, err = os.Create(fullPath) + var objectFile *os.File + objectFile, err = os.Create(fullPath) if err != nil { return } + defer objectFile.Close() - log.Info().Str("bucket", bucket).Str("key", key).Msg("Downloading from S3") + log.Info().Str("bucket", bucket).Str("key", objectKey).Msg("Downloading from S3") downloader := manager.NewDownloader(client) - _, err = downloader.Download(context.TODO(), file, &s3.GetObjectInput{ + _, err = downloader.Download(context.TODO(), objectFile, &s3.GetObjectInput{ Bucket: aws.String(bucket), - Key: aws.String(key), + Key: aws.String(objectKey), }) if err != nil { return @@ -340,27 +361,11 @@ func downloadTarFromS3(s3Url string) (tarPath string, err error) { } tarPath, err = tarDirectory(tempDirPath) - } else { - var file *os.File - file, err = os.CreateTemp(os.TempDir(), filepath.Base(key)) - if err != nil { - return - } - - log.Info().Str("bucket", bucket).Str("key", key).Msg("Downloading from S3") - - downloader := manager.NewDownloader(client) - _, err = downloader.Download(context.TODO(), file, &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if err != nil { - return - } - - tarPath = file.Name() + return } + tarPath = file.Name() + return }