GetCertificate from external certificate sources (Managers) (#163)

This work made possible by Tailscale: https://tailscale.com - thank you to the Tailscale team!

* Implement custom GetCertificate callback

Useful if another entity is managing certificates and can
provide its own dynamically during handshakes.

* Refactor CustomGetCertificate into OnDemandConfig

* Set certs to managed=true

This is only sorta true, but it allows handshake-time maintenance of the
certificates that are cached from CustomGetCertificate.

Our background maintenance routine skips certs that are OnDemand so it
should be fine.

* Change CustomGetCertificate into interface value

Instead of a function

* Case-insensitive subject name comparison

Hostnames are case-insensitive

Also add context to GetCertificate

* Export a couple of outrageously useful functions

* Allow multiple custom certificate getters

Also minor refactoring and enhancements

* Fix tests

* Rename Getter -> Manager; refactor

And don't cache externally managed certs

* Minor updates to comments
This commit is contained in:
Matt Holt 2022-02-17 14:37:50 -07:00 committed by GitHub
parent 134f03986c
commit 797d29bcf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 151 additions and 48 deletions

View File

@ -62,7 +62,7 @@ func (am *ACMEManager) loadAccount(ca, email string) (acme.Account, error) {
if err != nil { if err != nil {
return acct, err return acct, err
} }
acct.PrivateKey, err = decodePrivateKey(keyBytes) acct.PrivateKey, err = PEMDecodePrivateKey(keyBytes)
if err != nil { if err != nil {
return acct, fmt.Errorf("could not decode account's private key: %v", err) return acct, fmt.Errorf("could not decode account's private key: %v", err)
} }
@ -129,7 +129,7 @@ func (am *ACMEManager) lookUpAccount(ctx context.Context, privateKeyPEM []byte)
return acme.Account{}, fmt.Errorf("creating ACME client: %v", err) return acme.Account{}, fmt.Errorf("creating ACME client: %v", err)
} }
privateKey, err := decodePrivateKey([]byte(privateKeyPEM)) privateKey, err := PEMDecodePrivateKey([]byte(privateKeyPEM))
if err != nil { if err != nil {
return acme.Account{}, fmt.Errorf("decoding private key: %v", err) return acme.Account{}, fmt.Errorf("decoding private key: %v", err)
} }
@ -157,7 +157,7 @@ func (am *ACMEManager) saveAccount(ca string, account acme.Account) error {
if err != nil { if err != nil {
return err return err
} }
keyBytes, err := encodePrivateKey(account.PrivateKey) keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
if err != nil { if err != nil {
return err return err
} }

View File

@ -57,6 +57,12 @@ type Certificate struct {
issuerKey string issuerKey string
} }
// Empty returns true if the certificate struct is not filled out; at
// least the tls.Certificate.Certificate field is expected to be set.
func (cert Certificate) Empty() bool {
return len(cert.Certificate.Certificate) == 0
}
// NeedsRenewal returns true if the certificate is // NeedsRenewal returns true if the certificate is
// expiring soon (according to cfg) or has expired. // expiring soon (according to cfg) or has expired.
func (cert Certificate) NeedsRenewal(cfg *Config) bool { func (cert Certificate) NeedsRenewal(cfg *Config) bool {
@ -251,11 +257,15 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
// the leaf cert should be the one for the site; we must set // the leaf cert should be the one for the site; we must set
// the tls.Certificate.Leaf field so that TLS handshakes are // the tls.Certificate.Leaf field so that TLS handshakes are
// more efficient // more efficient
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0]) leaf := cert.Certificate.Leaf
if err != nil { if leaf == nil {
return err var err error
leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return err
}
cert.Certificate.Leaf = leaf
} }
cert.Certificate.Leaf = leaf
// for convenience, we do want to assemble all the // for convenience, we do want to assemble all the
// subjects on the certificate into one list // subjects on the certificate into one list
@ -393,9 +403,10 @@ func SubjectIsInternal(subj string) bool {
// states that IP addresses must match exactly, but this function // states that IP addresses must match exactly, but this function
// does not attempt to distinguish IP addresses from internal or // does not attempt to distinguish IP addresses from internal or
// external DNS names that happen to look like IP addresses. // external DNS names that happen to look like IP addresses.
// It uses DNS wildcard matching logic. // It uses DNS wildcard matching logic and is case-insensitive.
// https://tools.ietf.org/html/rfc2818#section-3.1 // https://tools.ietf.org/html/rfc2818#section-3.1
func MatchWildcard(subject, wildcard string) bool { func MatchWildcard(subject, wildcard string) bool {
subject, wildcard = strings.ToLower(subject), strings.ToLower(wildcard)
if subject == wildcard { if subject == wildcard {
return true return true
} }

View File

@ -27,7 +27,7 @@ func TestUnexportedGetCertificate(t *testing.T) {
cfg := &Config{certCache: certCache} cfg := &Config{certCache: certCache}
// When cache is empty // When cache is empty
if _, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}); matched || defaulted { if _, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "example.com"}); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
} }
@ -35,19 +35,19 @@ func TestUnexportedGetCertificate(t *testing.T) {
firstCert := Certificate{Names: []string{"example.com"}} firstCert := Certificate{Names: []string{"example.com"}}
certCache.cache["0xdeadbeef"] = firstCert certCache.cache["0xdeadbeef"] = firstCert
certCache.cacheIndex["example.com"] = []string{"0xdeadbeef"} certCache.cacheIndex["example.com"] = []string{"0xdeadbeef"}
if cert, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}); !matched || defaulted || cert.Names[0] != "example.com" { if cert, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "example.com"}); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When retrieving wildcard certificate // When retrieving wildcard certificate
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}} certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
certCache.cacheIndex["*.example.com"] = []string{"0xb01dface"} certCache.cacheIndex["*.example.com"] = []string{"0xb01dface"}
if cert, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "sub.example.com"}); !matched || defaulted || cert.Names[0] != "*.example.com" { if cert, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "sub.example.com"}); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
} }
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert) // When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
if cert, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "nomatch"}); matched || defaulted { if cert, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "nomatch"}); matched || defaulted {
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} }
} }
@ -190,10 +190,14 @@ func TestMatchWildcard(t *testing.T) {
expect bool expect bool
}{ }{
{"hostname", "hostname", true}, {"hostname", "hostname", true},
{"HOSTNAME", "hostname", true},
{"hostname", "HOSTNAME", true},
{"foo.localhost", "foo.localhost", true}, {"foo.localhost", "foo.localhost", true},
{"foo.localhost", "bar.localhost", false}, {"foo.localhost", "bar.localhost", false},
{"foo.localhost", "*.localhost", true}, {"foo.localhost", "*.localhost", true},
{"bar.localhost", "*.localhost", true}, {"bar.localhost", "*.localhost", true},
{"FOO.LocalHost", "*.localhost", true},
{"Bar.localhost", "*.LOCALHOST", true},
{"foo.bar.localhost", "*.localhost", false}, {"foo.bar.localhost", "*.localhost", false},
{".localhost", "*.localhost", false}, {".localhost", "*.localhost", false},
{"foo.localhost", "foo.*", false}, {"foo.localhost", "foo.*", false},

View File

@ -374,6 +374,18 @@ type Revoker interface {
Revoke(ctx context.Context, cert CertificateResource, reason int) error Revoke(ctx context.Context, cert CertificateResource, reason int) error
} }
// CertificateManager is a type that manages certificates (keeps them renewed)
// such that we can get certificates during TLS handshakes to immediately serve
// to clients.
//
// TODO: This is an EXPERIMENTAL API. It is subject to change/removal.
type CertificateManager interface {
// GetCertificate returns the certificate to use to complete the handshake.
// Since this is called during every TLS handshake, it must be very fast and not block.
// Returning (nil, nil) is valid and is simply treated as a no-op.
GetCertificate(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
}
// KeyGenerator can generate a private key. // KeyGenerator can generate a private key.
type KeyGenerator interface { type KeyGenerator interface {
// GenerateKey generates a private key. The returned // GenerateKey generates a private key. The returned

View File

@ -72,12 +72,21 @@ type Config struct {
// Adds the must staple TLS extension to the CSR. // Adds the must staple TLS extension to the CSR.
MustStaple bool MustStaple bool
// The source for getting new certificates; the // Sources for getting new, managed certificates;
// default Issuer is ACMEManager. If multiple // the default Issuer is ACMEManager. If multiple
// issuers are specified, they will be tried in // issuers are specified, they will be tried in
// turn until one succeeds. // turn until one succeeds.
Issuers []Issuer Issuers []Issuer
// Sources for getting new, unmanaged certificates.
// They will be invoked only during TLS handshakes
// before on-demand certificate management occurs,
// for certificates that are not already loaded into
// the in-memory cache.
//
// TODO: EXPERIMENTAL: subject to change and/or removal.
Managers []CertificateManager
// The source of new private keys for certificates; // The source of new private keys for certificates;
// the default KeySource is StandardKeyGenerator. // the default KeySource is StandardKeyGenerator.
KeySource KeyGenerator KeySource KeyGenerator
@ -499,7 +508,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool
if err != nil { if err != nil {
return err return err
} }
privKeyPEM, err = encodePrivateKey(privKey) privKeyPEM, err = PEMEncodePrivateKey(privKey)
if err != nil { if err != nil {
return err return err
} }
@ -605,7 +614,7 @@ func (cfg *Config) reusePrivateKey(domain string) (privKey crypto.PrivateKey, pr
} }
// we loaded a private key; try decoding it so we can use it // we loaded a private key; try decoding it so we can use it
privKey, err = decodePrivateKey(privKeyPEM) privKey, err = PEMDecodePrivateKey(privKeyPEM)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -722,7 +731,7 @@ func (cfg *Config) renewCert(ctx context.Context, name string, force, interactiv
zap.Duration("remaining", timeLeft)) zap.Duration("remaining", timeLeft))
} }
privateKey, err := decodePrivateKey(certRes.PrivateKeyPEM) privateKey, err := PEMDecodePrivateKey(certRes.PrivateKeyPEM)
if err != nil { if err != nil {
return err return err
} }

View File

@ -36,8 +36,10 @@ import (
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
// encodePrivateKey marshals a EC or RSA private key into a PEM-encoded array of bytes. // PEMEncodePrivateKey marshals a private key into a PEM-encoded block.
func encodePrivateKey(key crypto.PrivateKey) ([]byte, error) { // The private key must be one of *ecdsa.PrivateKey, *rsa.PrivateKey, or
// *ed25519.PrivateKey.
func PEMEncodePrivateKey(key crypto.PrivateKey) ([]byte, error) {
var pemType string var pemType string
var keyBytes []byte var keyBytes []byte
switch key := key.(type) { switch key := key.(type) {
@ -65,11 +67,13 @@ func encodePrivateKey(key crypto.PrivateKey) ([]byte, error) {
return pem.EncodeToMemory(&pemKey), nil return pem.EncodeToMemory(&pemKey), nil
} }
// decodePrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes. // PEMDecodePrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
// Borrowed from Go standard library, to handle various private key and PEM block types. // Borrowed from Go standard library, to handle various private key and PEM block types.
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308 func PEMDecodePrivateKey(keyPEMBytes []byte) (crypto.Signer, error) {
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238) // Modified from original:
func decodePrivateKey(keyPEMBytes []byte) (crypto.Signer, error) { // https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238
keyBlockDER, _ := pem.Decode(keyPEMBytes) keyBlockDER, _ := pem.Decode(keyPEMBytes)
if keyBlockDER == nil { if keyBlockDER == nil {

View File

@ -33,19 +33,19 @@ func TestEncodeDecodeRSAPrivateKey(t *testing.T) {
} }
// test save // test save
savedBytes, err := encodePrivateKey(privateKey) savedBytes, err := PEMEncodePrivateKey(privateKey)
if err != nil { if err != nil {
t.Fatal("error saving private key:", err) t.Fatal("error saving private key:", err)
} }
// test load // test load
loadedKey, err := decodePrivateKey(savedBytes) loadedKey, err := PEMDecodePrivateKey(savedBytes)
if err != nil { if err != nil {
t.Error("error loading private key:", err) t.Error("error loading private key:", err)
} }
// test load (should fail) // test load (should fail)
_, err = decodePrivateKey(savedBytes[2:]) _, err = PEMDecodePrivateKey(savedBytes[2:])
if err == nil { if err == nil {
t.Error("loading private key should have failed") t.Error("loading private key should have failed")
} }
@ -63,13 +63,13 @@ func TestSaveAndLoadECCPrivateKey(t *testing.T) {
} }
// test save // test save
savedBytes, err := encodePrivateKey(privateKey) savedBytes, err := PEMEncodePrivateKey(privateKey)
if err != nil { if err != nil {
t.Fatal("error saving private key:", err) t.Fatal("error saving private key:", err)
} }
// test load // test load
loadedKey, err := decodePrivateKey(savedBytes) loadedKey, err := PEMDecodePrivateKey(savedBytes)
if err != nil { if err != nil {
t.Error("error loading private key:", err) t.Error("error loading private key:", err)
} }

View File

@ -29,11 +29,15 @@ import (
) )
// GetCertificate gets a certificate to satisfy clientHello. In getting // GetCertificate gets a certificate to satisfy clientHello. In getting
// the certificate, it abides the rules and settings defined in the // the certificate, it abides the rules and settings defined in the Config
// Config that matches clientHello.ServerName. It first checks the in- // that matches clientHello.ServerName. It tries to get certificates in
// memory cache, then, if the config enables "OnDemand", it accesses // this order:
// disk, then accesses the network if it must obtain a new certificate //
// via ACME. // 1. Exact match in the in-memory cache
// 2. Wildcard match in the in-memory cache
// 3. Managers (if any)
// 4. Storage (if on-demand is enabled)
// 5. Issuers (if on-demand is enabled)
// //
// This method is safe for use as a tls.Config.GetCertificate callback. // This method is safe for use as a tls.Config.GetCertificate callback.
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
@ -71,7 +75,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
return &cert.Certificate, err return &cert.Certificate, err
} }
// getCertificate gets a certificate that matches name from the in-memory // getCertificateFromCache gets a certificate that matches name from the in-memory
// cache, according to the lookup table associated with cfg. The lookup then // cache, according to the lookup table associated with cfg. The lookup then
// points to a certificate in the Instance certificate cache. // points to a certificate in the Instance certificate cache.
// //
@ -87,7 +91,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
// which is by the Go Authors. // which is by the Go Authors.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cfg *Config) getCertificate(hello *tls.ClientHelloInfo) (cert Certificate, matched, defaulted bool) { func (cfg *Config) getCertificateFromCache(hello *tls.ClientHelloInfo) (cert Certificate, matched, defaulted bool) {
name := normalizedName(hello.ServerName) name := normalizedName(hello.ServerName)
if name == "" { if name == "" {
@ -216,24 +220,25 @@ func DefaultCertificateSelector(hello *tls.ClientHelloInfo, choices []Certificat
} }
// getCertDuringHandshake will get a certificate for hello. It first tries // getCertDuringHandshake will get a certificate for hello. It first tries
// the in-memory cache. If no certificate for hello is in the cache, the // the in-memory cache. If no exact certificate for hello is in the cache, the
// config most closely corresponding to hello will be loaded. If that config // config most closely corresponding to hello (like a wildcard) will be loaded.
// allows it (OnDemand==true) and if loadIfNecessary == true, it goes to disk // If none could be matched from the cache, it invokes the configured certificate
// to load it into the cache and serve it. If it's not on disk and if // managers to get a certificate and uses the first one that returns a certificate.
// obtainIfNecessary == true, the certificate will be obtained from the CA, // If no certificate managers return a value, and if the config allows it
// cached, and served. If obtainIfNecessary is true, then loadIfNecessary // (OnDemand!=nil) and if loadIfNecessary == true, it goes to storage to load the
// must also be set to true. An error will be returned if and only if no // cert into the cache and serve it. If it's not on disk and if
// certificate is available. // obtainIfNecessary == true, the certificate will be obtained from the CA, cached,
// and served. If obtainIfNecessary == true, then loadIfNecessary must also be == true.
// An error will be returned if and only if no certificate is available.
// //
// This function is safe for concurrent use. // This function is safe for concurrent use.
func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
log := loggerNamed(cfg.Logger, "handshake") log := loggerNamed(cfg.Logger, "handshake")
// TODO: get a proper context... somehow...? ctx := context.TODO() // TODO: get a proper context? from somewhere...
ctx := context.Background()
// First check our in-memory cache to see if we've already loaded it // First check our in-memory cache to see if we've already loaded it
cert, matched, defaulted := cfg.getCertificate(hello) cert, matched, defaulted := cfg.getCertificateFromCache(hello)
if matched { if matched {
if log != nil { if log != nil {
log.Debug("matched certificate in cache", log.Debug("matched certificate in cache",
@ -251,6 +256,16 @@ func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNece
return cert, nil return cert, nil
} }
// If an external CertificateManager is configured, try to get it from them.
// Only continue to use our own logic if it returns empty+nil.
externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, log)
if err != nil {
return Certificate{}, err
}
if !externalCert.Empty() {
return externalCert, nil
}
name := cfg.getNameFromClientHello(hello) name := cfg.getNameFromClientHello(hello)
// We might be able to load or obtain a needed certificate. Load from // We might be able to load or obtain a needed certificate. Load from
@ -627,9 +642,8 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
renewAndReload := func(ctx context.Context, cancel context.CancelFunc) (Certificate, error) { renewAndReload := func(ctx context.Context, cancel context.CancelFunc) (Certificate, error) {
defer cancel() defer cancel()
// otherwise, renew with issuer, etc.
var newCert Certificate var newCert Certificate
var err error
if revoked { if revoked {
newCert, err = cfg.forceRenew(ctx, log, currentCert) newCert, err = cfg.forceRenew(ctx, log, currentCert)
} else { } else {
@ -680,6 +694,55 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
return renewAndReload(ctx, cancel) return renewAndReload(ctx, cancel)
} }
// getCertFromAnyCertManager gets a certificate from cfg's Managers. If there are no Managers defined, this is
// a no-op that returns empty values. Otherwise, it gets a certificate for hello from the first Manager that
// returns a certificate and no error.
func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.ClientHelloInfo, log *zap.Logger) (Certificate, error) {
// fast path if nothing to do
if len(cfg.Managers) == 0 {
return Certificate{}, nil
}
var upstreamCert *tls.Certificate
// try all the GetCertificate methods on external managers; use first one that returns a certificate
for i, certManager := range cfg.Managers {
var err error
upstreamCert, err = certManager.GetCertificate(ctx, hello)
if err != nil {
log.Error("getting certificate from external certificate manager",
zap.String("sni", hello.ServerName),
zap.Int("cert_manager", i),
zap.Error(err))
continue
}
if upstreamCert != nil {
break
}
}
if upstreamCert == nil {
if log != nil {
log.Debug("all external certificate managers yielded no certificates and no errors", zap.String("sni", hello.ServerName))
}
return Certificate{}, nil
}
var cert Certificate
err := fillCertFromLeaf(&cert, *upstreamCert)
if err != nil {
return Certificate{}, fmt.Errorf("external certificate manager: %s: filling cert from leaf: %v", hello.ServerName, err)
}
if log != nil {
log.Debug("using externally-managed certificate",
zap.String("sni", hello.ServerName),
zap.Strings("names", cert.Names),
zap.Time("expiration", cert.Leaf.NotAfter))
}
return cert, nil
}
// getTLSALPNChallengeCert is to be called when the clientHello pertains to // getTLSALPNChallengeCert is to be called when the clientHello pertains to
// a TLS-ALPN challenge and a certificate is required to solve it. This method gets // a TLS-ALPN challenge and a certificate is required to solve it. This method gets
// the relevant challenge info and then returns the associated certificate (if any) // the relevant challenge info and then returns the associated certificate (if any)