package main import ( "fmt" "io/ioutil" "log" "net/http" "os" "path" "time" ) const ( userDataURL = "http://169.254.169.254/latest/user-data" metaDataURL = "http://169.254.169.254/latest/meta-data/" ) // ProviderAWS is the type implementing the Provider interface for AWS type ProviderAWS struct { } // NewAWS returns a new ProviderAWS func NewAWS() *ProviderAWS { return &ProviderAWS{} } func (p *ProviderAWS) String() string { return "AWS" } // Probe checks if we are running on AWS func (p *ProviderAWS) Probe() bool { // Getting the hostname should always work... _, err := awsGet(metaDataURL + "hostname") return (err == nil) } // Extract gets both the AWS specific and generic userdata func (p *ProviderAWS) Extract() ([]byte, error) { // Get host name. This must not fail hostname, err := awsGet(metaDataURL + "hostname") if err != nil { return nil, err } err = ioutil.WriteFile(path.Join(ConfigPath, Hostname), hostname, 0644) if err != nil { return nil, fmt.Errorf("AWS: Failed to write hostname: %s", err) } if err := p.handleSSH(); err != nil { log.Printf("AWS: Failed to get ssh data: %s", err) } // Generic userdata userData, err := awsGet(userDataURL) if err != nil { log.Printf("AWS: Failed to get user-data: %s", err) // This is not an error return nil, nil } return userData, nil } // awsGet requests and extracts the requested URL func awsGet(url string) ([]byte, error) { var client = &http.Client{ Timeout: time.Second * 2, } req, err := http.NewRequest("", url, nil) if err != nil { return nil, fmt.Errorf("AWS: http.NewRequest failed: %s", err) } resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("AWS: Could not contact metadata service: %s", err) } if resp.StatusCode != 200 { return nil, fmt.Errorf("AWS: Status not ok: %d", resp.StatusCode) } body, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("AWS: Failed to read http response: %s", err) } return body, nil } // SSH keys: func (p *ProviderAWS) handleSSH() error { sshKeys, err := awsGet(metaDataURL + "public-keys/0/openssh-key") if err != nil { return fmt.Errorf("Failed to get sshKeys: %s", err) } if err := os.Mkdir(path.Join(ConfigPath, SSH), 0755); err != nil { return fmt.Errorf("Failed to create %s: %s", SSH, err) } err = ioutil.WriteFile(path.Join(ConfigPath, SSH, "authorized_keys"), sshKeys, 0600) if err != nil { return fmt.Errorf("Failed to write ssh keys: %s", err) } return nil }