From 1a2275d54c24cbc0f85979ae3b2815d3b81aae15 Mon Sep 17 00:00:00 2001 From: a Date: Sun, 4 Aug 2024 14:37:03 -0500 Subject: [PATCH] fs storage: Use temporary files when writing (#300) * fix: use an tmp file to flush new certs to disk * add readme --- filestorage.go | 24 ++++- filestorage_test.go | 78 ++++++++++++++++ internal/atomicfile/README | 11 +++ internal/atomicfile/file.go | 148 +++++++++++++++++++++++++++++++ internal/atomicfile/file_test.go | 77 ++++++++++++++++ internal/testutil/readme.md | 27 ++++++ internal/testutil/testutil.go | 103 +++++++++++++++++++++ 7 files changed, 466 insertions(+), 2 deletions(-) create mode 100644 filestorage_test.go create mode 100644 internal/atomicfile/README create mode 100644 internal/atomicfile/file.go create mode 100644 internal/atomicfile/file_test.go create mode 100644 internal/testutil/readme.md create mode 100644 internal/testutil/testutil.go diff --git a/filestorage.go b/filestorage.go index f6f1360..d3df9cf 100644 --- a/filestorage.go +++ b/filestorage.go @@ -27,6 +27,8 @@ import ( "path/filepath" "runtime" "time" + + "github.com/caddyserver/certmagic/internal/atomicfile" ) // FileStorage facilitates forming file paths derived from a root @@ -82,12 +84,30 @@ func (s *FileStorage) Store(_ context.Context, key string, value []byte) error { if err != nil { return err } - return os.WriteFile(filename, value, 0600) + fp, err := atomicfile.New(filename, 0o600) + if err != nil { + return err + } + _, err = fp.Write(value) + if err != nil { + // cancel the write + fp.Cancel() + return err + } + // close, thereby flushing the write + return fp.Close() } // Load retrieves the value at key. func (s *FileStorage) Load(_ context.Context, key string) ([]byte, error) { - return os.ReadFile(s.Filename(key)) + // i believe it's possible for the read call to error but still return bytes, in event of something like a shortread? + // therefore, i think it's appropriate to not return any bytes to avoid downstream users of the package erroniously believing that + // bytes read + error is a valid response (it should not be) + xs, err := os.ReadFile(s.Filename(key)) + if err != nil { + return nil, err + } + return xs, nil } // Delete deletes the value at key. diff --git a/filestorage_test.go b/filestorage_test.go new file mode 100644 index 0000000..eafeec9 --- /dev/null +++ b/filestorage_test.go @@ -0,0 +1,78 @@ +package certmagic_test + +import ( + "bytes" + "context" + "os" + "testing" + + "github.com/caddyserver/certmagic" + "github.com/caddyserver/certmagic/internal/testutil" +) + +func TestFileStorageStoreLoad(t *testing.T) { + ctx := context.Background() + tmpDir, err := os.MkdirTemp(os.TempDir(), "certmagic*") + testutil.RequireNoError(t, err, "allocating tmp dir") + defer os.RemoveAll(tmpDir) + s := &certmagic.FileStorage{ + Path: tmpDir, + } + err = s.Store(ctx, "foo", []byte("bar")) + testutil.RequireNoError(t, err) + dat, err := s.Load(ctx, "foo") + testutil.RequireNoError(t, err) + testutil.RequireEqualValues(t, dat, []byte("bar")) +} + +func TestFileStorageStoreLoadRace(t *testing.T) { + ctx := context.Background() + tmpDir, err := os.MkdirTemp(os.TempDir(), "certmagic*") + testutil.RequireNoError(t, err, "allocating tmp dir") + defer os.RemoveAll(tmpDir) + s := &certmagic.FileStorage{ + Path: tmpDir, + } + a := bytes.Repeat([]byte("a"), 4096*1024) + b := bytes.Repeat([]byte("b"), 4096*1024) + err = s.Store(ctx, "foo", a) + testutil.RequireNoError(t, err) + done := make(chan struct{}) + go func() { + err := s.Store(ctx, "foo", b) + testutil.RequireNoError(t, err) + close(done) + }() + dat, err := s.Load(ctx, "foo") + <-done + testutil.RequireNoError(t, err) + testutil.RequireEqualValues(t, 4096*1024, len(dat)) +} + +func TestFileStorageWriteLock(t *testing.T) { + ctx := context.Background() + tmpDir, err := os.MkdirTemp(os.TempDir(), "certmagic*") + testutil.RequireNoError(t, err, "allocating tmp dir") + defer os.RemoveAll(tmpDir) + s := &certmagic.FileStorage{ + Path: tmpDir, + } + // cctx is a cancelled ctx. so if we can't immediately get the lock, it will fail + cctx, cn := context.WithCancel(ctx) + cn() + // should success + err = s.Lock(cctx, "foo") + testutil.RequireNoError(t, err) + // should fail + err = s.Lock(cctx, "foo") + testutil.RequireError(t, err) + + err = s.Unlock(cctx, "foo") + testutil.RequireNoError(t, err) + // shouldn't fail + err = s.Lock(cctx, "foo") + testutil.RequireNoError(t, err) + + err = s.Unlock(cctx, "foo") + testutil.RequireNoError(t, err) +} diff --git a/internal/atomicfile/README b/internal/atomicfile/README new file mode 100644 index 0000000..17d04dd --- /dev/null +++ b/internal/atomicfile/README @@ -0,0 +1,11 @@ +# atomic file + + +this is copied from + +https://github.com/containerd/containerd/blob/main/pkg%2Fatomicfile%2Ffile.go + + +see + +https://github.com/caddyserver/certmagic/issues/296 diff --git a/internal/atomicfile/file.go b/internal/atomicfile/file.go new file mode 100644 index 0000000..7b870f7 --- /dev/null +++ b/internal/atomicfile/file.go @@ -0,0 +1,148 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +/* +Package atomicfile provides a mechanism (on Unix-like platforms) to present a consistent view of a file to separate +processes even while the file is being written. This is accomplished by writing a temporary file, syncing to disk, and +renaming over the destination file name. + +Partial/inconsistent reads can occur due to: + 1. A process attempting to read the file while it is being written to (both in the case of a new file with a + short/incomplete write or in the case of an existing, updated file where new bytes may be written at the beginning + but old bytes may still be present after). + 2. Concurrent goroutines leading to multiple active writers of the same file. + +The above mechanism explicitly protects against (1) as all writes are to a file with a temporary name. + +There is no explicit protection against multiple, concurrent goroutines attempting to write the same file. However, +atomically writing the file should mean only one writer will "win" and a consistent file will be visible. + +Note: atomicfile is partially implemented for Windows. The Windows codepath performs the same operations, however +Windows does not guarantee that a rename operation is atomic; a crash in the middle may leave the destination file +truncated rather than with the expected content. +*/ +package atomicfile + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "sync" +) + +// File is an io.ReadWriteCloser that can also be Canceled if a change needs to be abandoned. +type File interface { + io.ReadWriteCloser + // Cancel abandons a change to a file. This can be called if a write fails or another error occurs. + Cancel() error +} + +// ErrClosed is returned if Read or Write are called on a closed File. +var ErrClosed = errors.New("file is closed") + +// New returns a new atomic file. On Unix-like platforms, the writer (an io.ReadWriteCloser) is backed by a temporary +// file placed into the same directory as the destination file (using filepath.Dir to split the directory from the +// name). On a call to Close the temporary file is synced to disk and renamed to its final name, hiding any previous +// file by the same name. +// +// Note: Take care to call Close and handle any errors that are returned. Errors returned from Close may indicate that +// the file was not written with its final name. +func New(name string, mode os.FileMode) (File, error) { + return newFile(name, mode) +} + +type atomicFile struct { + name string + f *os.File + closed bool + closedMu sync.RWMutex +} + +func newFile(name string, mode os.FileMode) (File, error) { + dir := filepath.Dir(name) + f, err := os.CreateTemp(dir, "") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + if err := f.Chmod(mode); err != nil { + return nil, fmt.Errorf("failed to change temp file permissions: %w", err) + } + return &atomicFile{name: name, f: f}, nil +} + +func (a *atomicFile) Close() (err error) { + a.closedMu.Lock() + defer a.closedMu.Unlock() + + if a.closed { + return nil + } + a.closed = true + + defer func() { + if err != nil { + _ = os.Remove(a.f.Name()) // ignore errors + } + }() + // The order of operations here is: + // 1. sync + // 2. close + // 3. rename + // While the ordering of 2 and 3 is not important on Unix-like operating systems, Windows cannot rename an open + // file. By closing first, we allow the rename operation to succeed. + if err = a.f.Sync(); err != nil { + return fmt.Errorf("failed to sync temp file %q: %w", a.f.Name(), err) + } + if err = a.f.Close(); err != nil { + return fmt.Errorf("failed to close temp file %q: %w", a.f.Name(), err) + } + if err = os.Rename(a.f.Name(), a.name); err != nil { + return fmt.Errorf("failed to rename %q to %q: %w", a.f.Name(), a.name, err) + } + return nil +} + +func (a *atomicFile) Cancel() error { + a.closedMu.Lock() + defer a.closedMu.Unlock() + + if a.closed { + return nil + } + a.closed = true + _ = a.f.Close() // ignore error + return os.Remove(a.f.Name()) +} + +func (a *atomicFile) Read(p []byte) (n int, err error) { + a.closedMu.RLock() + defer a.closedMu.RUnlock() + if a.closed { + return 0, ErrClosed + } + return a.f.Read(p) +} + +func (a *atomicFile) Write(p []byte) (n int, err error) { + a.closedMu.RLock() + defer a.closedMu.RUnlock() + if a.closed { + return 0, ErrClosed + } + return a.f.Write(p) +} diff --git a/internal/atomicfile/file_test.go b/internal/atomicfile/file_test.go new file mode 100644 index 0000000..8a3162c --- /dev/null +++ b/internal/atomicfile/file_test.go @@ -0,0 +1,77 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package atomicfile_test + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/caddyserver/certmagic/internal/atomicfile" + "github.com/caddyserver/certmagic/internal/testutil" +) + +func TestFile(t *testing.T) { + const content = "this is some test content for a file" + dir := t.TempDir() + path := filepath.Join(dir, "test-file") + + f, err := atomicfile.New(path, 0o644) + testutil.RequireNoError(t, err, "failed to create file") + n, err := fmt.Fprint(f, content) + testutil.AssertNoError(t, err, "failed to write content") + testutil.AssertEqual(t, len(content), n, "written bytes should be equal") + err = f.Close() + testutil.RequireNoError(t, err, "failed to close file") + + actual, err := os.ReadFile(path) + testutil.AssertNoError(t, err, "failed to read file") + testutil.AssertEqual(t, content, string(actual)) +} + +func TestConcurrentWrites(t *testing.T) { + const content1 = "this is the first content of the file. there should be none other." + const content2 = "the second content of the file should win!" + dir := t.TempDir() + path := filepath.Join(dir, "test-file") + + file1, err := atomicfile.New(path, 0o600) + testutil.RequireNoError(t, err, "failed to create file1") + file2, err := atomicfile.New(path, 0o644) + testutil.RequireNoError(t, err, "failed to create file2") + + n, err := fmt.Fprint(file1, content1) + testutil.AssertNoError(t, err, "failed to write content1") + testutil.AssertEqual(t, len(content1), n, "written bytes should be equal") + + n, err = fmt.Fprint(file2, content2) + testutil.AssertNoError(t, err, "failed to write content2") + testutil.AssertEqual(t, len(content2), n, "written bytes should be equal") + + err = file1.Close() + testutil.RequireNoError(t, err, "failed to close file1") + actual, err := os.ReadFile(path) + testutil.AssertNoError(t, err, "failed to read file") + testutil.AssertEqual(t, content1, string(actual)) + + err = file2.Close() + testutil.RequireNoError(t, err, "failed to close file2") + actual, err = os.ReadFile(path) + testutil.AssertNoError(t, err, "failed to read file") + testutil.AssertEqual(t, content2, string(actual)) +} diff --git a/internal/testutil/readme.md b/internal/testutil/readme.md new file mode 100644 index 0000000..bfdacd0 --- /dev/null +++ b/internal/testutil/readme.md @@ -0,0 +1,27 @@ +# testutil + +some testing functions copied out of github.com/stretchr/testify, which were originally used by atomicfile + +``` +MIT License + +Copyright (c) 2012-2020 Mat Ryer, Tyler Bunnell and contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +``` diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..b6e9574 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,103 @@ +package testutil + +import ( + "bytes" + "fmt" + "reflect" + "strings" + "testing" +) + +func AssertNoError(t *testing.T, err error, msgAndArgs ...string) { + if err != nil { + Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...) + } +} + +func AssertEqual(t *testing.T, a, b any, msgAndArgs ...string) { + if !ObjectsAreEqual(a, b) { + Fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", a, b), msgAndArgs...) + } +} +func RequireNoError(t *testing.T, err error, msgAndArgs ...string) { + if err != nil { + Failnow(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...) + } +} +func RequireError(t *testing.T, err error, msgAndArgs ...string) { + if err == nil { + Failnow(t, fmt.Sprintf("Received no error when expecting error:\n%+v", err), msgAndArgs...) + } +} + +func RequireEqual(t *testing.T, a, b any, msgAndArgs ...string) { + if !ObjectsAreEqual(a, b) { + Failnow(t, fmt.Sprintf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", a, b), msgAndArgs...) + } +} +func RequireEqualValues(t *testing.T, a, b any, msgAndArgs ...string) { + if !ObjectsAreEqualValues(a, b) { + Failnow(t, fmt.Sprintf("Not equal: \n"+ + "expected: %v\n"+ + "actual : %v", a, b), msgAndArgs...) + } +} + +func Fail(t testing.TB, xs string, msgs ...string) { + var testName string + // Add test name if the Go version supports it + if n, ok := t.(interface { + Name() string + }); ok { + testName = n.Name() + } + + t.Errorf("error %s:%s\n%s\n", testName, xs, strings.Join(msgs, "")) +} + +func Failnow(t testing.TB, xs string, msgs ...string) { + Fail(t, xs, msgs...) + t.FailNow() +} + +func ObjectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +} + +func ObjectsAreEqualValues(expected, actual interface{}) bool { + if ObjectsAreEqual(expected, actual) { + return true + } + + actualType := reflect.TypeOf(actual) + if actualType == nil { + return false + } + expectedValue := reflect.ValueOf(expected) + if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + } + + return false +}