GetCertificate from external certificate sources (Managers) (#163)
This work made possible by Tailscale: https://tailscale.com - thank you to the Tailscale team! * Implement custom GetCertificate callback Useful if another entity is managing certificates and can provide its own dynamically during handshakes. * Refactor CustomGetCertificate into OnDemandConfig * Set certs to managed=true This is only sorta true, but it allows handshake-time maintenance of the certificates that are cached from CustomGetCertificate. Our background maintenance routine skips certs that are OnDemand so it should be fine. * Change CustomGetCertificate into interface value Instead of a function * Case-insensitive subject name comparison Hostnames are case-insensitive Also add context to GetCertificate * Export a couple of outrageously useful functions * Allow multiple custom certificate getters Also minor refactoring and enhancements * Fix tests * Rename Getter -> Manager; refactor And don't cache externally managed certs * Minor updates to comments
This commit is contained in:
parent
134f03986c
commit
797d29bcf3
@ -62,7 +62,7 @@ func (am *ACMEManager) loadAccount(ca, email string) (acme.Account, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return acct, err
|
return acct, err
|
||||||
}
|
}
|
||||||
acct.PrivateKey, err = decodePrivateKey(keyBytes)
|
acct.PrivateKey, err = PEMDecodePrivateKey(keyBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return acct, fmt.Errorf("could not decode account's private key: %v", err)
|
return acct, fmt.Errorf("could not decode account's private key: %v", err)
|
||||||
}
|
}
|
||||||
@ -129,7 +129,7 @@ func (am *ACMEManager) lookUpAccount(ctx context.Context, privateKeyPEM []byte)
|
|||||||
return acme.Account{}, fmt.Errorf("creating ACME client: %v", err)
|
return acme.Account{}, fmt.Errorf("creating ACME client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := decodePrivateKey([]byte(privateKeyPEM))
|
privateKey, err := PEMDecodePrivateKey([]byte(privateKeyPEM))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return acme.Account{}, fmt.Errorf("decoding private key: %v", err)
|
return acme.Account{}, fmt.Errorf("decoding private key: %v", err)
|
||||||
}
|
}
|
||||||
@ -157,7 +157,7 @@ func (am *ACMEManager) saveAccount(ca string, account acme.Account) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
keyBytes, err := encodePrivateKey(account.PrivateKey)
|
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -57,6 +57,12 @@ type Certificate struct {
|
|||||||
issuerKey string
|
issuerKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Empty returns true if the certificate struct is not filled out; at
|
||||||
|
// least the tls.Certificate.Certificate field is expected to be set.
|
||||||
|
func (cert Certificate) Empty() bool {
|
||||||
|
return len(cert.Certificate.Certificate) == 0
|
||||||
|
}
|
||||||
|
|
||||||
// NeedsRenewal returns true if the certificate is
|
// NeedsRenewal returns true if the certificate is
|
||||||
// expiring soon (according to cfg) or has expired.
|
// expiring soon (according to cfg) or has expired.
|
||||||
func (cert Certificate) NeedsRenewal(cfg *Config) bool {
|
func (cert Certificate) NeedsRenewal(cfg *Config) bool {
|
||||||
@ -251,11 +257,15 @@ func fillCertFromLeaf(cert *Certificate, tlsCert tls.Certificate) error {
|
|||||||
// the leaf cert should be the one for the site; we must set
|
// the leaf cert should be the one for the site; we must set
|
||||||
// the tls.Certificate.Leaf field so that TLS handshakes are
|
// the tls.Certificate.Leaf field so that TLS handshakes are
|
||||||
// more efficient
|
// more efficient
|
||||||
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
|
leaf := cert.Certificate.Leaf
|
||||||
if err != nil {
|
if leaf == nil {
|
||||||
return err
|
var err error
|
||||||
|
leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cert.Certificate.Leaf = leaf
|
||||||
}
|
}
|
||||||
cert.Certificate.Leaf = leaf
|
|
||||||
|
|
||||||
// for convenience, we do want to assemble all the
|
// for convenience, we do want to assemble all the
|
||||||
// subjects on the certificate into one list
|
// subjects on the certificate into one list
|
||||||
@ -393,9 +403,10 @@ func SubjectIsInternal(subj string) bool {
|
|||||||
// states that IP addresses must match exactly, but this function
|
// states that IP addresses must match exactly, but this function
|
||||||
// does not attempt to distinguish IP addresses from internal or
|
// does not attempt to distinguish IP addresses from internal or
|
||||||
// external DNS names that happen to look like IP addresses.
|
// external DNS names that happen to look like IP addresses.
|
||||||
// It uses DNS wildcard matching logic.
|
// It uses DNS wildcard matching logic and is case-insensitive.
|
||||||
// https://tools.ietf.org/html/rfc2818#section-3.1
|
// https://tools.ietf.org/html/rfc2818#section-3.1
|
||||||
func MatchWildcard(subject, wildcard string) bool {
|
func MatchWildcard(subject, wildcard string) bool {
|
||||||
|
subject, wildcard = strings.ToLower(subject), strings.ToLower(wildcard)
|
||||||
if subject == wildcard {
|
if subject == wildcard {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ func TestUnexportedGetCertificate(t *testing.T) {
|
|||||||
cfg := &Config{certCache: certCache}
|
cfg := &Config{certCache: certCache}
|
||||||
|
|
||||||
// When cache is empty
|
// When cache is empty
|
||||||
if _, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}); matched || defaulted {
|
if _, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "example.com"}); matched || defaulted {
|
||||||
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,19 +35,19 @@ func TestUnexportedGetCertificate(t *testing.T) {
|
|||||||
firstCert := Certificate{Names: []string{"example.com"}}
|
firstCert := Certificate{Names: []string{"example.com"}}
|
||||||
certCache.cache["0xdeadbeef"] = firstCert
|
certCache.cache["0xdeadbeef"] = firstCert
|
||||||
certCache.cacheIndex["example.com"] = []string{"0xdeadbeef"}
|
certCache.cacheIndex["example.com"] = []string{"0xdeadbeef"}
|
||||||
if cert, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}); !matched || defaulted || cert.Names[0] != "example.com" {
|
if cert, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "example.com"}); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||||
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When retrieving wildcard certificate
|
// When retrieving wildcard certificate
|
||||||
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
|
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
|
||||||
certCache.cacheIndex["*.example.com"] = []string{"0xb01dface"}
|
certCache.cacheIndex["*.example.com"] = []string{"0xb01dface"}
|
||||||
if cert, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "sub.example.com"}); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
if cert, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "sub.example.com"}); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
||||||
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||||
}
|
}
|
||||||
|
|
||||||
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
||||||
if cert, matched, defaulted := cfg.getCertificate(&tls.ClientHelloInfo{ServerName: "nomatch"}); matched || defaulted {
|
if cert, matched, defaulted := cfg.getCertificateFromCache(&tls.ClientHelloInfo{ServerName: "nomatch"}); matched || defaulted {
|
||||||
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,10 +190,14 @@ func TestMatchWildcard(t *testing.T) {
|
|||||||
expect bool
|
expect bool
|
||||||
}{
|
}{
|
||||||
{"hostname", "hostname", true},
|
{"hostname", "hostname", true},
|
||||||
|
{"HOSTNAME", "hostname", true},
|
||||||
|
{"hostname", "HOSTNAME", true},
|
||||||
{"foo.localhost", "foo.localhost", true},
|
{"foo.localhost", "foo.localhost", true},
|
||||||
{"foo.localhost", "bar.localhost", false},
|
{"foo.localhost", "bar.localhost", false},
|
||||||
{"foo.localhost", "*.localhost", true},
|
{"foo.localhost", "*.localhost", true},
|
||||||
{"bar.localhost", "*.localhost", true},
|
{"bar.localhost", "*.localhost", true},
|
||||||
|
{"FOO.LocalHost", "*.localhost", true},
|
||||||
|
{"Bar.localhost", "*.LOCALHOST", true},
|
||||||
{"foo.bar.localhost", "*.localhost", false},
|
{"foo.bar.localhost", "*.localhost", false},
|
||||||
{".localhost", "*.localhost", false},
|
{".localhost", "*.localhost", false},
|
||||||
{"foo.localhost", "foo.*", false},
|
{"foo.localhost", "foo.*", false},
|
||||||
|
12
certmagic.go
12
certmagic.go
@ -374,6 +374,18 @@ type Revoker interface {
|
|||||||
Revoke(ctx context.Context, cert CertificateResource, reason int) error
|
Revoke(ctx context.Context, cert CertificateResource, reason int) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CertificateManager is a type that manages certificates (keeps them renewed)
|
||||||
|
// such that we can get certificates during TLS handshakes to immediately serve
|
||||||
|
// to clients.
|
||||||
|
//
|
||||||
|
// TODO: This is an EXPERIMENTAL API. It is subject to change/removal.
|
||||||
|
type CertificateManager interface {
|
||||||
|
// GetCertificate returns the certificate to use to complete the handshake.
|
||||||
|
// Since this is called during every TLS handshake, it must be very fast and not block.
|
||||||
|
// Returning (nil, nil) is valid and is simply treated as a no-op.
|
||||||
|
GetCertificate(context.Context, *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
// KeyGenerator can generate a private key.
|
// KeyGenerator can generate a private key.
|
||||||
type KeyGenerator interface {
|
type KeyGenerator interface {
|
||||||
// GenerateKey generates a private key. The returned
|
// GenerateKey generates a private key. The returned
|
||||||
|
19
config.go
19
config.go
@ -72,12 +72,21 @@ type Config struct {
|
|||||||
// Adds the must staple TLS extension to the CSR.
|
// Adds the must staple TLS extension to the CSR.
|
||||||
MustStaple bool
|
MustStaple bool
|
||||||
|
|
||||||
// The source for getting new certificates; the
|
// Sources for getting new, managed certificates;
|
||||||
// default Issuer is ACMEManager. If multiple
|
// the default Issuer is ACMEManager. If multiple
|
||||||
// issuers are specified, they will be tried in
|
// issuers are specified, they will be tried in
|
||||||
// turn until one succeeds.
|
// turn until one succeeds.
|
||||||
Issuers []Issuer
|
Issuers []Issuer
|
||||||
|
|
||||||
|
// Sources for getting new, unmanaged certificates.
|
||||||
|
// They will be invoked only during TLS handshakes
|
||||||
|
// before on-demand certificate management occurs,
|
||||||
|
// for certificates that are not already loaded into
|
||||||
|
// the in-memory cache.
|
||||||
|
//
|
||||||
|
// TODO: EXPERIMENTAL: subject to change and/or removal.
|
||||||
|
Managers []CertificateManager
|
||||||
|
|
||||||
// The source of new private keys for certificates;
|
// The source of new private keys for certificates;
|
||||||
// the default KeySource is StandardKeyGenerator.
|
// the default KeySource is StandardKeyGenerator.
|
||||||
KeySource KeyGenerator
|
KeySource KeyGenerator
|
||||||
@ -499,7 +508,7 @@ func (cfg *Config) obtainCert(ctx context.Context, name string, interactive bool
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
privKeyPEM, err = encodePrivateKey(privKey)
|
privKeyPEM, err = PEMEncodePrivateKey(privKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -605,7 +614,7 @@ func (cfg *Config) reusePrivateKey(domain string) (privKey crypto.PrivateKey, pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// we loaded a private key; try decoding it so we can use it
|
// we loaded a private key; try decoding it so we can use it
|
||||||
privKey, err = decodePrivateKey(privKeyPEM)
|
privKey, err = PEMDecodePrivateKey(privKeyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
@ -722,7 +731,7 @@ func (cfg *Config) renewCert(ctx context.Context, name string, force, interactiv
|
|||||||
zap.Duration("remaining", timeLeft))
|
zap.Duration("remaining", timeLeft))
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := decodePrivateKey(certRes.PrivateKeyPEM)
|
privateKey, err := PEMDecodePrivateKey(certRes.PrivateKeyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
16
crypto.go
16
crypto.go
@ -36,8 +36,10 @@ import (
|
|||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
// encodePrivateKey marshals a EC or RSA private key into a PEM-encoded array of bytes.
|
// PEMEncodePrivateKey marshals a private key into a PEM-encoded block.
|
||||||
func encodePrivateKey(key crypto.PrivateKey) ([]byte, error) {
|
// The private key must be one of *ecdsa.PrivateKey, *rsa.PrivateKey, or
|
||||||
|
// *ed25519.PrivateKey.
|
||||||
|
func PEMEncodePrivateKey(key crypto.PrivateKey) ([]byte, error) {
|
||||||
var pemType string
|
var pemType string
|
||||||
var keyBytes []byte
|
var keyBytes []byte
|
||||||
switch key := key.(type) {
|
switch key := key.(type) {
|
||||||
@ -65,11 +67,13 @@ func encodePrivateKey(key crypto.PrivateKey) ([]byte, error) {
|
|||||||
return pem.EncodeToMemory(&pemKey), nil
|
return pem.EncodeToMemory(&pemKey), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodePrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
|
// PEMDecodePrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
|
||||||
// Borrowed from Go standard library, to handle various private key and PEM block types.
|
// Borrowed from Go standard library, to handle various private key and PEM block types.
|
||||||
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308
|
func PEMDecodePrivateKey(keyPEMBytes []byte) (crypto.Signer, error) {
|
||||||
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238)
|
// Modified from original:
|
||||||
func decodePrivateKey(keyPEMBytes []byte) (crypto.Signer, error) {
|
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L291-L308
|
||||||
|
// https://github.com/golang/go/blob/693748e9fa385f1e2c3b91ca9acbb6c0ad2d133d/src/crypto/tls/tls.go#L238
|
||||||
|
|
||||||
keyBlockDER, _ := pem.Decode(keyPEMBytes)
|
keyBlockDER, _ := pem.Decode(keyPEMBytes)
|
||||||
|
|
||||||
if keyBlockDER == nil {
|
if keyBlockDER == nil {
|
||||||
|
@ -33,19 +33,19 @@ func TestEncodeDecodeRSAPrivateKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// test save
|
// test save
|
||||||
savedBytes, err := encodePrivateKey(privateKey)
|
savedBytes, err := PEMEncodePrivateKey(privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("error saving private key:", err)
|
t.Fatal("error saving private key:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// test load
|
// test load
|
||||||
loadedKey, err := decodePrivateKey(savedBytes)
|
loadedKey, err := PEMDecodePrivateKey(savedBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("error loading private key:", err)
|
t.Error("error loading private key:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// test load (should fail)
|
// test load (should fail)
|
||||||
_, err = decodePrivateKey(savedBytes[2:])
|
_, err = PEMDecodePrivateKey(savedBytes[2:])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("loading private key should have failed")
|
t.Error("loading private key should have failed")
|
||||||
}
|
}
|
||||||
@ -63,13 +63,13 @@ func TestSaveAndLoadECCPrivateKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// test save
|
// test save
|
||||||
savedBytes, err := encodePrivateKey(privateKey)
|
savedBytes, err := PEMEncodePrivateKey(privateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("error saving private key:", err)
|
t.Fatal("error saving private key:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// test load
|
// test load
|
||||||
loadedKey, err := decodePrivateKey(savedBytes)
|
loadedKey, err := PEMDecodePrivateKey(savedBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("error loading private key:", err)
|
t.Error("error loading private key:", err)
|
||||||
}
|
}
|
||||||
|
103
handshake.go
103
handshake.go
@ -29,11 +29,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// GetCertificate gets a certificate to satisfy clientHello. In getting
|
// GetCertificate gets a certificate to satisfy clientHello. In getting
|
||||||
// the certificate, it abides the rules and settings defined in the
|
// the certificate, it abides the rules and settings defined in the Config
|
||||||
// Config that matches clientHello.ServerName. It first checks the in-
|
// that matches clientHello.ServerName. It tries to get certificates in
|
||||||
// memory cache, then, if the config enables "OnDemand", it accesses
|
// this order:
|
||||||
// disk, then accesses the network if it must obtain a new certificate
|
//
|
||||||
// via ACME.
|
// 1. Exact match in the in-memory cache
|
||||||
|
// 2. Wildcard match in the in-memory cache
|
||||||
|
// 3. Managers (if any)
|
||||||
|
// 4. Storage (if on-demand is enabled)
|
||||||
|
// 5. Issuers (if on-demand is enabled)
|
||||||
//
|
//
|
||||||
// This method is safe for use as a tls.Config.GetCertificate callback.
|
// This method is safe for use as a tls.Config.GetCertificate callback.
|
||||||
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
@ -71,7 +75,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
|
|||||||
return &cert.Certificate, err
|
return &cert.Certificate, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCertificate gets a certificate that matches name from the in-memory
|
// getCertificateFromCache gets a certificate that matches name from the in-memory
|
||||||
// cache, according to the lookup table associated with cfg. The lookup then
|
// cache, according to the lookup table associated with cfg. The lookup then
|
||||||
// points to a certificate in the Instance certificate cache.
|
// points to a certificate in the Instance certificate cache.
|
||||||
//
|
//
|
||||||
@ -87,7 +91,7 @@ func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certif
|
|||||||
// which is by the Go Authors.
|
// which is by the Go Authors.
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func (cfg *Config) getCertificate(hello *tls.ClientHelloInfo) (cert Certificate, matched, defaulted bool) {
|
func (cfg *Config) getCertificateFromCache(hello *tls.ClientHelloInfo) (cert Certificate, matched, defaulted bool) {
|
||||||
name := normalizedName(hello.ServerName)
|
name := normalizedName(hello.ServerName)
|
||||||
|
|
||||||
if name == "" {
|
if name == "" {
|
||||||
@ -216,24 +220,25 @@ func DefaultCertificateSelector(hello *tls.ClientHelloInfo, choices []Certificat
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getCertDuringHandshake will get a certificate for hello. It first tries
|
// getCertDuringHandshake will get a certificate for hello. It first tries
|
||||||
// the in-memory cache. If no certificate for hello is in the cache, the
|
// the in-memory cache. If no exact certificate for hello is in the cache, the
|
||||||
// config most closely corresponding to hello will be loaded. If that config
|
// config most closely corresponding to hello (like a wildcard) will be loaded.
|
||||||
// allows it (OnDemand==true) and if loadIfNecessary == true, it goes to disk
|
// If none could be matched from the cache, it invokes the configured certificate
|
||||||
// to load it into the cache and serve it. If it's not on disk and if
|
// managers to get a certificate and uses the first one that returns a certificate.
|
||||||
// obtainIfNecessary == true, the certificate will be obtained from the CA,
|
// If no certificate managers return a value, and if the config allows it
|
||||||
// cached, and served. If obtainIfNecessary is true, then loadIfNecessary
|
// (OnDemand!=nil) and if loadIfNecessary == true, it goes to storage to load the
|
||||||
// must also be set to true. An error will be returned if and only if no
|
// cert into the cache and serve it. If it's not on disk and if
|
||||||
// certificate is available.
|
// obtainIfNecessary == true, the certificate will be obtained from the CA, cached,
|
||||||
|
// and served. If obtainIfNecessary == true, then loadIfNecessary must also be == true.
|
||||||
|
// An error will be returned if and only if no certificate is available.
|
||||||
//
|
//
|
||||||
// This function is safe for concurrent use.
|
// This function is safe for concurrent use.
|
||||||
func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||||
log := loggerNamed(cfg.Logger, "handshake")
|
log := loggerNamed(cfg.Logger, "handshake")
|
||||||
|
|
||||||
// TODO: get a proper context... somehow...?
|
ctx := context.TODO() // TODO: get a proper context? from somewhere...
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// First check our in-memory cache to see if we've already loaded it
|
// First check our in-memory cache to see if we've already loaded it
|
||||||
cert, matched, defaulted := cfg.getCertificate(hello)
|
cert, matched, defaulted := cfg.getCertificateFromCache(hello)
|
||||||
if matched {
|
if matched {
|
||||||
if log != nil {
|
if log != nil {
|
||||||
log.Debug("matched certificate in cache",
|
log.Debug("matched certificate in cache",
|
||||||
@ -251,6 +256,16 @@ func (cfg *Config) getCertDuringHandshake(hello *tls.ClientHelloInfo, loadIfNece
|
|||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If an external CertificateManager is configured, try to get it from them.
|
||||||
|
// Only continue to use our own logic if it returns empty+nil.
|
||||||
|
externalCert, err := cfg.getCertFromAnyCertManager(ctx, hello, log)
|
||||||
|
if err != nil {
|
||||||
|
return Certificate{}, err
|
||||||
|
}
|
||||||
|
if !externalCert.Empty() {
|
||||||
|
return externalCert, nil
|
||||||
|
}
|
||||||
|
|
||||||
name := cfg.getNameFromClientHello(hello)
|
name := cfg.getNameFromClientHello(hello)
|
||||||
|
|
||||||
// We might be able to load or obtain a needed certificate. Load from
|
// We might be able to load or obtain a needed certificate. Load from
|
||||||
@ -627,9 +642,8 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
|
|||||||
renewAndReload := func(ctx context.Context, cancel context.CancelFunc) (Certificate, error) {
|
renewAndReload := func(ctx context.Context, cancel context.CancelFunc) (Certificate, error) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// otherwise, renew with issuer, etc.
|
||||||
var newCert Certificate
|
var newCert Certificate
|
||||||
var err error
|
|
||||||
|
|
||||||
if revoked {
|
if revoked {
|
||||||
newCert, err = cfg.forceRenew(ctx, log, currentCert)
|
newCert, err = cfg.forceRenew(ctx, log, currentCert)
|
||||||
} else {
|
} else {
|
||||||
@ -680,6 +694,55 @@ func (cfg *Config) renewDynamicCertificate(ctx context.Context, hello *tls.Clien
|
|||||||
return renewAndReload(ctx, cancel)
|
return renewAndReload(ctx, cancel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getCertFromAnyCertManager gets a certificate from cfg's Managers. If there are no Managers defined, this is
|
||||||
|
// a no-op that returns empty values. Otherwise, it gets a certificate for hello from the first Manager that
|
||||||
|
// returns a certificate and no error.
|
||||||
|
func (cfg *Config) getCertFromAnyCertManager(ctx context.Context, hello *tls.ClientHelloInfo, log *zap.Logger) (Certificate, error) {
|
||||||
|
// fast path if nothing to do
|
||||||
|
if len(cfg.Managers) == 0 {
|
||||||
|
return Certificate{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var upstreamCert *tls.Certificate
|
||||||
|
|
||||||
|
// try all the GetCertificate methods on external managers; use first one that returns a certificate
|
||||||
|
for i, certManager := range cfg.Managers {
|
||||||
|
var err error
|
||||||
|
upstreamCert, err = certManager.GetCertificate(ctx, hello)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("getting certificate from external certificate manager",
|
||||||
|
zap.String("sni", hello.ServerName),
|
||||||
|
zap.Int("cert_manager", i),
|
||||||
|
zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if upstreamCert != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if upstreamCert == nil {
|
||||||
|
if log != nil {
|
||||||
|
log.Debug("all external certificate managers yielded no certificates and no errors", zap.String("sni", hello.ServerName))
|
||||||
|
}
|
||||||
|
return Certificate{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cert Certificate
|
||||||
|
err := fillCertFromLeaf(&cert, *upstreamCert)
|
||||||
|
if err != nil {
|
||||||
|
return Certificate{}, fmt.Errorf("external certificate manager: %s: filling cert from leaf: %v", hello.ServerName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if log != nil {
|
||||||
|
log.Debug("using externally-managed certificate",
|
||||||
|
zap.String("sni", hello.ServerName),
|
||||||
|
zap.Strings("names", cert.Names),
|
||||||
|
zap.Time("expiration", cert.Leaf.NotAfter))
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
// getTLSALPNChallengeCert is to be called when the clientHello pertains to
|
// getTLSALPNChallengeCert is to be called when the clientHello pertains to
|
||||||
// a TLS-ALPN challenge and a certificate is required to solve it. This method gets
|
// a TLS-ALPN challenge and a certificate is required to solve it. This method gets
|
||||||
// the relevant challenge info and then returns the associated certificate (if any)
|
// the relevant challenge info and then returns the associated certificate (if any)
|
||||||
|
Loading…
Reference in New Issue
Block a user