mirror of
https://github.com/kairos-io/tpm-helpers.git
synced 2025-09-22 01:37:24 +00:00
Do an mdns lookup when the domain of the KMS is ending in .local
Part of: https://github.com/kairos-io/kairos/issues/2069 Signed-off-by: Dimitris Karakasilis <dimitris@karakasilis.me>
This commit is contained in:
68
get.go
68
get.go
@@ -8,14 +8,20 @@ import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-attestation/attest"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/hashicorp/mdns"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const MDNSTimeout = 15 * time.Second
|
||||
const MDNSServiceName = "_kcrypt._tcp"
|
||||
|
||||
// GetAuthToken generates an authentication token from the host TPM.
|
||||
// It will return the token as a string and the generated AK that should
|
||||
// be saved by the caller for later Authentication.
|
||||
@@ -129,6 +135,12 @@ func Connection(url string, opts ...Option) (*websocket.Conn, error) {
|
||||
header = http.Header{}
|
||||
}
|
||||
|
||||
var err error
|
||||
url, err = checkMDNSDomain(url, &header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialer := websocket.DefaultDialer
|
||||
if len(c.cacerts) > 0 {
|
||||
pool := x509.NewCertPool()
|
||||
@@ -237,3 +249,59 @@ func getChallengeResponse(c *config, ec *attest.EncryptedCredential, aikBytes []
|
||||
Secret: secret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func checkMDNSDomain(originalURL string, headers *http.Header) (string, error) {
|
||||
parsedURL, err := url.Parse(originalURL)
|
||||
if err != nil {
|
||||
return originalURL, fmt.Errorf("parsing the mdns url: %w", err)
|
||||
}
|
||||
|
||||
host := parsedURL.Host
|
||||
if !strings.HasSuffix(host, ".local") {
|
||||
return originalURL, nil
|
||||
}
|
||||
|
||||
mdnsIP, mdnsPort := discoverMDNS(host)
|
||||
if mdnsIP == "" { // no reply
|
||||
return originalURL, nil
|
||||
}
|
||||
|
||||
headers.Add("Host", parsedURL.Host)
|
||||
newURL := strings.ReplaceAll(originalURL, host, mdnsIP)
|
||||
// Remove any port in the original url
|
||||
if port := parsedURL.Port(); port != "" {
|
||||
newURL = strings.ReplaceAll(newURL, port, "")
|
||||
}
|
||||
|
||||
// Add any possible port from the mdns response
|
||||
if mdnsPort != "" {
|
||||
newURL = strings.ReplaceAll(newURL, mdnsIP, fmt.Sprintf("%s:%s", mdnsIP, mdnsPort))
|
||||
}
|
||||
|
||||
return newURL, nil
|
||||
}
|
||||
|
||||
func discoverMDNS(host string) (string, string) {
|
||||
// Make a channel for results and start listening
|
||||
entriesCh := make(chan *mdns.ServiceEntry, 4)
|
||||
defer close(entriesCh)
|
||||
|
||||
// Start the lookup.
|
||||
// The channel is buffered so it doesn't block.
|
||||
// The Lookup here has its own timeout until it receives a response.
|
||||
// We use a select with a timeout to read because we don't know if we didn't
|
||||
// get the response yet or if there will be no response at all.
|
||||
mdns.Lookup(MDNSServiceName, entriesCh)
|
||||
|
||||
select {
|
||||
case entry := <-entriesCh:
|
||||
// TODO: For now we don't care what the actual "host" is set to. Any response
|
||||
// will do. Maybe in the future we can verify with that entry.Host matches host,
|
||||
// or something like that but it's not a security measure. Anyone could bring up
|
||||
// a server that advertises to be "_kcrypt._tcp" type of service as long as they
|
||||
// can connect to the same network.
|
||||
return entry.AddrV4.String(), strconv.Itoa(entry.Port) // TODO: v6?
|
||||
case <-time.After(MDNSTimeout):
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user