Use configured email to pin to specific account key in storage (#283)

* Use the `email` configuration in the ACME issuer to "pin" an account to a key

When the issuer is configured with both an email and key material, these should match -- but that also means we
can use the email information to predict the key-key, skipping the potentially expensive storage.List operation.

* `continue` when we cannot load the private key for an account

Not being able to load this might be caused by a storage problem, or it could have been something
we did earlier. In either case we do not know whether this is the account we're looking for, and breaking
out now will trigger expensive calls to the ACME server to lookup the account and then save that account
again even though it was perfectly fine to begin with.

* Add unit tests for the changed behaviors
This commit is contained in:
Andreas Kohn 2024-04-18 21:42:33 +02:00 committed by GitHub
parent f64401c80d
commit 27ab129028
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 253 additions and 5 deletions

View File

@ -88,11 +88,18 @@ func (*ACMEIssuer) newAccount(email string) (acme.Account, error) {
// If it does not exist in storage, it will be retrieved from the ACME server and added to storage.
// The account must already exist; it does not create a new account.
func (am *ACMEIssuer) GetAccount(ctx context.Context, privateKeyPEM []byte) (acme.Account, error) {
account, err := am.loadAccountByKey(ctx, privateKeyPEM)
if errors.Is(err, fs.ErrNotExist) {
account, err = am.lookUpAccount(ctx, privateKeyPEM)
email := am.getEmail()
if email == "" {
if account, err := am.loadAccountByKey(ctx, privateKeyPEM); err == nil {
return account, nil
}
} else {
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
if err == nil && bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
return am.loadAccount(ctx, am.CA, email)
}
}
return account, err
return am.lookUpAccount(ctx, privateKeyPEM)
}
// loadAccountByKey loads the account with the given private key from storage, if it exists.
@ -107,9 +114,14 @@ func (am *ACMEIssuer) loadAccountByKey(ctx context.Context, privateKeyPEM []byte
email := path.Base(accountFolderKey)
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
if err != nil {
return acme.Account{}, err
// Try the next account: This one is missing its private key, if it turns out to be the one we're looking
// for we will try to save it again after confirming with the ACME server.
continue
}
if bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
// Found the account with the correct private key, try loading it. If this fails we we will follow
// the same procedure as if the private key was not found and confirm with the ACME server before saving
// it again.
return am.loadAccount(ctx, am.CA, email)
}
}

View File

@ -17,6 +17,7 @@ package certmagic
import (
"bytes"
"context"
"io/fs"
"os"
"path/filepath"
"reflect"
@ -26,6 +27,131 @@ import (
"time"
)
// memoryStorage is an in-memory storage implementation with known contents *and* fixed iteration order for List.
type memoryStorage struct {
contents []memoryStorageItem
}
type memoryStorageItem struct {
key string
data []byte
}
func (m *memoryStorage) lookup(_ context.Context, key string) *memoryStorageItem {
for _, item := range m.contents {
if item.key == key {
return &item
}
}
return nil
}
func (m *memoryStorage) Delete(ctx context.Context, key string) error {
for i, item := range m.contents {
if item.key == key {
m.contents = append(m.contents[:i], m.contents[i+1:]...)
return nil
}
}
return fs.ErrNotExist
}
func (m *memoryStorage) Store(ctx context.Context, key string, value []byte) error {
m.contents = append(m.contents, memoryStorageItem{key: key, data: value})
return nil
}
func (m *memoryStorage) Exists(ctx context.Context, key string) bool {
return m.lookup(ctx, key) != nil
}
func (m *memoryStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) {
if recursive {
panic("unimplemented")
}
result := []string{}
nextitem:
for _, item := range m.contents {
if !strings.HasPrefix(item.key, path+"/") {
continue
}
name := strings.TrimPrefix(item.key, path+"/")
if i := strings.Index(name, "/"); i >= 0 {
name = name[:i]
}
for _, existing := range result {
if existing == name {
continue nextitem
}
}
result = append(result, name)
}
return result, nil
}
func (m *memoryStorage) Load(ctx context.Context, key string) ([]byte, error) {
if item := m.lookup(ctx, key); item != nil {
return item.data, nil
}
return nil, fs.ErrNotExist
}
func (m *memoryStorage) Stat(ctx context.Context, key string) (KeyInfo, error) {
if item := m.lookup(ctx, key); item != nil {
return KeyInfo{Key: key, Size: int64(len(item.data))}, nil
}
return KeyInfo{}, fs.ErrNotExist
}
func (m *memoryStorage) Lock(ctx context.Context, name string) error { panic("unimplemented") }
func (m *memoryStorage) Unlock(ctx context.Context, name string) error { panic("unimplemented") }
var _ Storage = (*memoryStorage)(nil)
type recordingStorage struct {
Storage
calls []recordedCall
}
func (r *recordingStorage) Delete(ctx context.Context, key string) error {
r.record("Delete", key)
return r.Storage.Delete(ctx, key)
}
func (r *recordingStorage) Exists(ctx context.Context, key string) bool {
r.record("Exists", key)
return r.Storage.Exists(ctx, key)
}
func (r *recordingStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) {
r.record("List", path, recursive)
return r.Storage.List(ctx, path, recursive)
}
func (r *recordingStorage) Load(ctx context.Context, key string) ([]byte, error) {
r.record("Load", key)
return r.Storage.Load(ctx, key)
}
func (r *recordingStorage) Lock(ctx context.Context, name string) error {
r.record("Lock", name)
return r.Storage.Lock(ctx, name)
}
func (r *recordingStorage) Stat(ctx context.Context, key string) (KeyInfo, error) {
r.record("Stat", key)
return r.Storage.Stat(ctx, key)
}
func (r *recordingStorage) Store(ctx context.Context, key string, value []byte) error {
r.record("Store", key)
return r.Storage.Store(ctx, key, value)
}
func (r *recordingStorage) Unlock(ctx context.Context, name string) error {
r.record("Unlock", name)
return r.Storage.Unlock(ctx, name)
}
type recordedCall struct {
name string
args []interface{}
}
func (r *recordingStorage) record(name string, args ...interface{}) {
r.calls = append(r.calls, recordedCall{name: name, args: args})
}
var _ Storage = (*recordingStorage)(nil)
func TestNewAccount(t *testing.T) {
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
testConfig := &Config{
@ -159,6 +285,116 @@ func TestGetAccountAlreadyExists(t *testing.T) {
}
}
func TestGetAccountAlreadyExistsSkipsBroken(t *testing.T) {
ctx := context.Background()
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
testConfig := &Config{
Issuers: []Issuer{am},
Storage: &memoryStorage{},
Logger: defaultTestLogger,
certCache: new(Cache),
}
am.config = testConfig
email := "me@foobar.com"
// Create a "corrupted" account
am.config.Storage.Store(ctx, am.storageKeyUserReg(am.CA, "notmeatall@foobar.com"), []byte("this is not a valid account"))
// Create the actual account
account, err := am.newAccount(email)
if err != nil {
t.Fatalf("Error creating account: %v", err)
}
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving account: %v", err)
}
// Expect to load account from disk
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
if err != nil {
t.Fatalf("Error encoding private key: %v", err)
}
loadedAccount, err := am.GetAccount(ctx, keyBytes)
if err != nil {
t.Fatalf("Error getting account: %v", err)
}
// Assert keys are the same
if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) {
t.Error("Expected private key to be the same after loading, but it wasn't")
}
// Assert emails are the same
if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) {
t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact)
}
}
func TestGetAccountWithEmailAlreadyExists(t *testing.T) {
ctx := context.Background()
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
testConfig := &Config{
Issuers: []Issuer{am},
Storage: &recordingStorage{Storage: &memoryStorage{}},
Logger: defaultTestLogger,
certCache: new(Cache),
}
am.config = testConfig
email := "me@foobar.com"
// Set up test
account, err := am.newAccount(email)
if err != nil {
t.Fatalf("Error creating account: %v", err)
}
err = am.saveAccount(ctx, am.CA, account)
if err != nil {
t.Fatalf("Error saving account: %v", err)
}
// Set the expected email:
am.Email = email
err = am.setEmail(ctx, true)
if err != nil {
t.Fatalf("setEmail error: %v", err)
}
// Expect to load account from disk
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
if err != nil {
t.Fatalf("Error encoding private key: %v", err)
}
loadedAccount, err := am.GetAccount(ctx, keyBytes)
if err != nil {
t.Fatalf("Error getting account: %v", err)
}
// Assert keys are the same
if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) {
t.Error("Expected private key to be the same after loading, but it wasn't")
}
// Assert emails are the same
if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) {
t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact)
}
// Assert that this was found without listing all accounts
rs := testConfig.Storage.(*recordingStorage)
for _, call := range rs.calls {
if call.name == "List" {
t.Error("Unexpected List call")
}
}
}
func TestGetEmailFromPackageDefault(t *testing.T) {
ctx := context.Background()