Use configured email to pin to specific account key in storage (#283)
* Use the `email` configuration in the ACME issuer to "pin" an account to a key When the issuer is configured with both an email and key material, these should match -- but that also means we can use the email information to predict the key-key, skipping the potentially expensive storage.List operation. * `continue` when we cannot load the private key for an account Not being able to load this might be caused by a storage problem, or it could have been something we did earlier. In either case we do not know whether this is the account we're looking for, and breaking out now will trigger expensive calls to the ACME server to lookup the account and then save that account again even though it was perfectly fine to begin with. * Add unit tests for the changed behaviors
This commit is contained in:
parent
f64401c80d
commit
27ab129028
22
account.go
22
account.go
@ -88,11 +88,18 @@ func (*ACMEIssuer) newAccount(email string) (acme.Account, error) {
|
||||
// If it does not exist in storage, it will be retrieved from the ACME server and added to storage.
|
||||
// The account must already exist; it does not create a new account.
|
||||
func (am *ACMEIssuer) GetAccount(ctx context.Context, privateKeyPEM []byte) (acme.Account, error) {
|
||||
account, err := am.loadAccountByKey(ctx, privateKeyPEM)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
account, err = am.lookUpAccount(ctx, privateKeyPEM)
|
||||
email := am.getEmail()
|
||||
if email == "" {
|
||||
if account, err := am.loadAccountByKey(ctx, privateKeyPEM); err == nil {
|
||||
return account, nil
|
||||
}
|
||||
} else {
|
||||
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
|
||||
if err == nil && bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
|
||||
return am.loadAccount(ctx, am.CA, email)
|
||||
}
|
||||
}
|
||||
return account, err
|
||||
return am.lookUpAccount(ctx, privateKeyPEM)
|
||||
}
|
||||
|
||||
// loadAccountByKey loads the account with the given private key from storage, if it exists.
|
||||
@ -107,9 +114,14 @@ func (am *ACMEIssuer) loadAccountByKey(ctx context.Context, privateKeyPEM []byte
|
||||
email := path.Base(accountFolderKey)
|
||||
keyBytes, err := am.config.Storage.Load(ctx, am.storageKeyUserPrivateKey(am.CA, email))
|
||||
if err != nil {
|
||||
return acme.Account{}, err
|
||||
// Try the next account: This one is missing its private key, if it turns out to be the one we're looking
|
||||
// for we will try to save it again after confirming with the ACME server.
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(bytes.TrimSpace(keyBytes), bytes.TrimSpace(privateKeyPEM)) {
|
||||
// Found the account with the correct private key, try loading it. If this fails we we will follow
|
||||
// the same procedure as if the private key was not found and confirm with the ACME server before saving
|
||||
// it again.
|
||||
return am.loadAccount(ctx, am.CA, email)
|
||||
}
|
||||
}
|
||||
|
236
account_test.go
236
account_test.go
@ -17,6 +17,7 @@ package certmagic
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
@ -26,6 +27,131 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// memoryStorage is an in-memory storage implementation with known contents *and* fixed iteration order for List.
|
||||
type memoryStorage struct {
|
||||
contents []memoryStorageItem
|
||||
}
|
||||
|
||||
type memoryStorageItem struct {
|
||||
key string
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (m *memoryStorage) lookup(_ context.Context, key string) *memoryStorageItem {
|
||||
for _, item := range m.contents {
|
||||
if item.key == key {
|
||||
return &item
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *memoryStorage) Delete(ctx context.Context, key string) error {
|
||||
for i, item := range m.contents {
|
||||
if item.key == key {
|
||||
m.contents = append(m.contents[:i], m.contents[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fs.ErrNotExist
|
||||
}
|
||||
func (m *memoryStorage) Store(ctx context.Context, key string, value []byte) error {
|
||||
m.contents = append(m.contents, memoryStorageItem{key: key, data: value})
|
||||
return nil
|
||||
}
|
||||
func (m *memoryStorage) Exists(ctx context.Context, key string) bool {
|
||||
return m.lookup(ctx, key) != nil
|
||||
}
|
||||
func (m *memoryStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) {
|
||||
if recursive {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
result := []string{}
|
||||
nextitem:
|
||||
for _, item := range m.contents {
|
||||
if !strings.HasPrefix(item.key, path+"/") {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimPrefix(item.key, path+"/")
|
||||
if i := strings.Index(name, "/"); i >= 0 {
|
||||
name = name[:i]
|
||||
}
|
||||
|
||||
for _, existing := range result {
|
||||
if existing == name {
|
||||
continue nextitem
|
||||
}
|
||||
}
|
||||
result = append(result, name)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (m *memoryStorage) Load(ctx context.Context, key string) ([]byte, error) {
|
||||
if item := m.lookup(ctx, key); item != nil {
|
||||
return item.data, nil
|
||||
}
|
||||
return nil, fs.ErrNotExist
|
||||
}
|
||||
func (m *memoryStorage) Stat(ctx context.Context, key string) (KeyInfo, error) {
|
||||
if item := m.lookup(ctx, key); item != nil {
|
||||
return KeyInfo{Key: key, Size: int64(len(item.data))}, nil
|
||||
}
|
||||
return KeyInfo{}, fs.ErrNotExist
|
||||
}
|
||||
func (m *memoryStorage) Lock(ctx context.Context, name string) error { panic("unimplemented") }
|
||||
func (m *memoryStorage) Unlock(ctx context.Context, name string) error { panic("unimplemented") }
|
||||
|
||||
var _ Storage = (*memoryStorage)(nil)
|
||||
|
||||
type recordingStorage struct {
|
||||
Storage
|
||||
calls []recordedCall
|
||||
}
|
||||
|
||||
func (r *recordingStorage) Delete(ctx context.Context, key string) error {
|
||||
r.record("Delete", key)
|
||||
return r.Storage.Delete(ctx, key)
|
||||
}
|
||||
func (r *recordingStorage) Exists(ctx context.Context, key string) bool {
|
||||
r.record("Exists", key)
|
||||
return r.Storage.Exists(ctx, key)
|
||||
}
|
||||
func (r *recordingStorage) List(ctx context.Context, path string, recursive bool) ([]string, error) {
|
||||
r.record("List", path, recursive)
|
||||
return r.Storage.List(ctx, path, recursive)
|
||||
}
|
||||
func (r *recordingStorage) Load(ctx context.Context, key string) ([]byte, error) {
|
||||
r.record("Load", key)
|
||||
return r.Storage.Load(ctx, key)
|
||||
}
|
||||
func (r *recordingStorage) Lock(ctx context.Context, name string) error {
|
||||
r.record("Lock", name)
|
||||
return r.Storage.Lock(ctx, name)
|
||||
}
|
||||
func (r *recordingStorage) Stat(ctx context.Context, key string) (KeyInfo, error) {
|
||||
r.record("Stat", key)
|
||||
return r.Storage.Stat(ctx, key)
|
||||
}
|
||||
func (r *recordingStorage) Store(ctx context.Context, key string, value []byte) error {
|
||||
r.record("Store", key)
|
||||
return r.Storage.Store(ctx, key, value)
|
||||
}
|
||||
func (r *recordingStorage) Unlock(ctx context.Context, name string) error {
|
||||
r.record("Unlock", name)
|
||||
return r.Storage.Unlock(ctx, name)
|
||||
}
|
||||
|
||||
type recordedCall struct {
|
||||
name string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
func (r *recordingStorage) record(name string, args ...interface{}) {
|
||||
r.calls = append(r.calls, recordedCall{name: name, args: args})
|
||||
}
|
||||
|
||||
var _ Storage = (*recordingStorage)(nil)
|
||||
|
||||
func TestNewAccount(t *testing.T) {
|
||||
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
|
||||
testConfig := &Config{
|
||||
@ -159,6 +285,116 @@ func TestGetAccountAlreadyExists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountAlreadyExistsSkipsBroken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
|
||||
testConfig := &Config{
|
||||
Issuers: []Issuer{am},
|
||||
Storage: &memoryStorage{},
|
||||
Logger: defaultTestLogger,
|
||||
certCache: new(Cache),
|
||||
}
|
||||
am.config = testConfig
|
||||
|
||||
email := "me@foobar.com"
|
||||
|
||||
// Create a "corrupted" account
|
||||
am.config.Storage.Store(ctx, am.storageKeyUserReg(am.CA, "notmeatall@foobar.com"), []byte("this is not a valid account"))
|
||||
|
||||
// Create the actual account
|
||||
account, err := am.newAccount(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating account: %v", err)
|
||||
}
|
||||
err = am.saveAccount(ctx, am.CA, account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving account: %v", err)
|
||||
}
|
||||
|
||||
// Expect to load account from disk
|
||||
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Error encoding private key: %v", err)
|
||||
}
|
||||
|
||||
loadedAccount, err := am.GetAccount(ctx, keyBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting account: %v", err)
|
||||
}
|
||||
|
||||
// Assert keys are the same
|
||||
if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) {
|
||||
t.Error("Expected private key to be the same after loading, but it wasn't")
|
||||
}
|
||||
|
||||
// Assert emails are the same
|
||||
if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) {
|
||||
t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountWithEmailAlreadyExists(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
am := &ACMEIssuer{CA: dummyCA, mu: new(sync.Mutex)}
|
||||
testConfig := &Config{
|
||||
Issuers: []Issuer{am},
|
||||
Storage: &recordingStorage{Storage: &memoryStorage{}},
|
||||
Logger: defaultTestLogger,
|
||||
certCache: new(Cache),
|
||||
}
|
||||
am.config = testConfig
|
||||
|
||||
email := "me@foobar.com"
|
||||
|
||||
// Set up test
|
||||
account, err := am.newAccount(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating account: %v", err)
|
||||
}
|
||||
err = am.saveAccount(ctx, am.CA, account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving account: %v", err)
|
||||
}
|
||||
|
||||
// Set the expected email:
|
||||
am.Email = email
|
||||
err = am.setEmail(ctx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("setEmail error: %v", err)
|
||||
}
|
||||
|
||||
// Expect to load account from disk
|
||||
keyBytes, err := PEMEncodePrivateKey(account.PrivateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Error encoding private key: %v", err)
|
||||
}
|
||||
|
||||
loadedAccount, err := am.GetAccount(ctx, keyBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting account: %v", err)
|
||||
}
|
||||
|
||||
// Assert keys are the same
|
||||
if !privateKeysSame(account.PrivateKey, loadedAccount.PrivateKey) {
|
||||
t.Error("Expected private key to be the same after loading, but it wasn't")
|
||||
}
|
||||
|
||||
// Assert emails are the same
|
||||
if !reflect.DeepEqual(account.Contact, loadedAccount.Contact) {
|
||||
t.Errorf("Expected contacts to be equal, but was '%s' before and '%s' after loading", account.Contact, loadedAccount.Contact)
|
||||
}
|
||||
|
||||
// Assert that this was found without listing all accounts
|
||||
rs := testConfig.Storage.(*recordingStorage)
|
||||
for _, call := range rs.calls {
|
||||
if call.name == "List" {
|
||||
t.Error("Unexpected List call")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEmailFromPackageDefault(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user