diff --git a/dav/client/client.go b/dav/client/client.go index aa63811..0a015c4 100644 --- a/dav/client/client.go +++ b/dav/client/client.go @@ -52,6 +52,10 @@ func NewWithStorageClient(storageClient StorageClient) *DavBlobstore { func (d *DavBlobstore) Put(sourceFilePath string, dest string) error { slog.Info("uploading file to webdav", "source", sourceFilePath, "dest", dest) + if err := validateBlobID(dest); err != nil { + return err + } + source, err := os.Open(sourceFilePath) if err != nil { return fmt.Errorf("failed to open source file: %w", err) @@ -75,6 +79,10 @@ func (d *DavBlobstore) Put(sourceFilePath string, dest string) error { func (d *DavBlobstore) Get(source string, dest string) error { slog.Info("downloading file from webdav", "source", source, "dest", dest) + if err := validateBlobID(source); err != nil { + return err + } + destFile, err := os.Create(dest) if err != nil { return fmt.Errorf("failed to create destination file: %w", err) @@ -98,17 +106,25 @@ func (d *DavBlobstore) Get(source string, dest string) error { func (d *DavBlobstore) Delete(dest string) error { slog.Info("deleting file from webdav", "dest", dest) + if err := validateBlobID(dest); err != nil { + return err + } return d.storageClient.Delete(dest) } func (d *DavBlobstore) Exists(dest string) (bool, error) { slog.Info("checking if file exists on webdav", "dest", dest) + if err := validateBlobID(dest); err != nil { + return false, err + } return d.storageClient.Exists(dest) } func (d *DavBlobstore) Sign(dest string, action string, expiration time.Duration) (string, error) { slog.Info("signing url for webdav", "dest", dest, "action", action, "expiration", expiration) - + if err := validateBlobID(dest); err != nil { + return "", err + } action = strings.ToUpper(action) switch action { case "GET", "PUT": @@ -122,27 +138,41 @@ func (d *DavBlobstore) Sign(dest string, action string, expiration time.Duration } } -// DeleteRecursive is not yet implemented in this refactoring func (d *DavBlobstore) DeleteRecursive(prefix string) error { - return fmt.Errorf("DeleteRecursive not yet implemented") + slog.Info("deleting blobs recursively from webdav", "prefix", prefix) + return d.storageClient.DeleteRecursive(prefix) } -// List is not yet implemented in this refactoring func (d *DavBlobstore) List(prefix string) ([]string, error) { - return nil, fmt.Errorf("List not yet implemented") + slog.Info("listing blobs on webdav", "prefix", prefix) + if prefix != "" { + if err := validatePrefix(prefix); err != nil { + return nil, err + } + } + return d.storageClient.List(prefix) } -// Copy is not yet implemented in this refactoring func (d *DavBlobstore) Copy(srcBlob string, dstBlob string) error { - return fmt.Errorf("Copy not yet implemented") + slog.Info("copying blob on webdav", "src", srcBlob, "dst", dstBlob) + if err := validateBlobID(srcBlob); err != nil { + return fmt.Errorf("invalid source blob ID: %w", err) + } + if err := validateBlobID(dstBlob); err != nil { + return fmt.Errorf("invalid destination blob ID: %w", err) + } + return d.storageClient.Copy(srcBlob, dstBlob) } -// Properties is not yet implemented in this refactoring func (d *DavBlobstore) Properties(dest string) error { - return fmt.Errorf("Properties not yet implemented") + slog.Info("fetching blob properties from webdav", "dest", dest) + if err := validateBlobID(dest); err != nil { + return err + } + return d.storageClient.Properties(dest) } -// EnsureStorageExists is not yet implemented in this refactoring func (d *DavBlobstore) EnsureStorageExists() error { - return fmt.Errorf("EnsureStorageExists not yet implemented") + slog.Info("ensuring webdav storage root exists") + return d.storageClient.EnsureStorageExists() } diff --git a/dav/client/client_test.go b/dav/client/client_test.go index b55ff7f..23efc5b 100644 --- a/dav/client/client_test.go +++ b/dav/client/client_test.go @@ -5,6 +5,7 @@ import ( "io" "os" "strings" + "time" "github.com/cloudfoundry/storage-cli/dav/client" "github.com/cloudfoundry/storage-cli/dav/client/clientfakes" @@ -124,4 +125,193 @@ var _ = Describe("Client", func() { Expect(exists).To(BeFalse()) }) }) + + Context("Sign", func() { + var expiry = 100 * time.Second + + It("returns a signed URL for action 'get'", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.SignReturns("https://the-signed-url", nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + url, err := davBlobstore.Sign("blob/path", "get", expiry) + + Expect(err).NotTo(HaveOccurred()) + Expect(url).To(Equal("https://the-signed-url")) + + Expect(fakeStorageClient.SignCallCount()).To(Equal(1)) + object, action, expiration := fakeStorageClient.SignArgsForCall(0) + Expect(object).To(Equal("blob/path")) + Expect(action).To(Equal("GET")) + Expect(expiration).To(Equal(expiry)) + }) + + It("returns a signed URL for action 'put'", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.SignReturns("https://the-signed-url", nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + url, err := davBlobstore.Sign("blob/path", "put", expiry) + + Expect(err).NotTo(HaveOccurred()) + Expect(url).To(Equal("https://the-signed-url")) + + _, action, _ := fakeStorageClient.SignArgsForCall(0) + Expect(action).To(Equal("PUT")) + }) + + It("fails on unknown action without calling the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + url, err := davBlobstore.Sign("blob/path", "unknown", expiry) + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("action not implemented")) + Expect(url).To(Equal("")) + Expect(fakeStorageClient.SignCallCount()).To(Equal(0)) + }) + + It("propagates errors from the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.SignReturns("", fmt.Errorf("boom")) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + url, err := davBlobstore.Sign("blob/path", "get", expiry) + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("boom")) + Expect(url).To(Equal("")) + }) + }) + + Context("Copy", func() { + It("forwards source and destination to the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.CopyReturns(nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.Copy("src/blob", "dst/blob") + + Expect(err).NotTo(HaveOccurred()) + Expect(fakeStorageClient.CopyCallCount()).To(Equal(1)) + + src, dst := fakeStorageClient.CopyArgsForCall(0) + Expect(src).To(Equal("src/blob")) + Expect(dst).To(Equal("dst/blob")) + }) + + It("propagates errors from the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.CopyReturns(fmt.Errorf("copy failed")) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.Copy("src/blob", "dst/blob") + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("copy failed")) + }) + }) + + Context("List", func() { + It("returns the blobs reported by the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.ListReturns([]string{"a/b/c", "a/b/d"}, nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + blobs, err := davBlobstore.List("a/b") + + Expect(err).NotTo(HaveOccurred()) + Expect(blobs).To(ConsistOf("a/b/c", "a/b/d")) + + Expect(fakeStorageClient.ListCallCount()).To(Equal(1)) + Expect(fakeStorageClient.ListArgsForCall(0)).To(Equal("a/b")) + }) + + It("propagates errors from the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.ListReturns(nil, fmt.Errorf("list failed")) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + blobs, err := davBlobstore.List("any/prefix") + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("list failed")) + Expect(blobs).To(BeNil()) + }) + }) + + Context("DeleteRecursive", func() { + It("forwards the prefix to the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.DeleteRecursiveReturns(nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.DeleteRecursive("some/prefix") + + Expect(err).NotTo(HaveOccurred()) + Expect(fakeStorageClient.DeleteRecursiveCallCount()).To(Equal(1)) + Expect(fakeStorageClient.DeleteRecursiveArgsForCall(0)).To(Equal("some/prefix")) + }) + + It("propagates errors from the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.DeleteRecursiveReturns(fmt.Errorf("recursive delete failed")) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.DeleteRecursive("some/prefix") + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("recursive delete failed")) + }) + }) + + Context("Properties", func() { + It("forwards the destination to the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.PropertiesReturns(nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.Properties("blob/path") + + Expect(err).NotTo(HaveOccurred()) + Expect(fakeStorageClient.PropertiesCallCount()).To(Equal(1)) + Expect(fakeStorageClient.PropertiesArgsForCall(0)).To(Equal("blob/path")) + }) + + It("propagates errors from the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.PropertiesReturns(fmt.Errorf("properties failed")) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.Properties("blob/path") + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("properties failed")) + }) + }) + + Context("EnsureStorageExists", func() { + It("delegates to the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.EnsureStorageExistsReturns(nil) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.EnsureStorageExists() + + Expect(err).NotTo(HaveOccurred()) + Expect(fakeStorageClient.EnsureStorageExistsCallCount()).To(Equal(1)) + }) + + It("propagates errors from the storage client", func() { + fakeStorageClient := &clientfakes.FakeStorageClient{} + fakeStorageClient.EnsureStorageExistsReturns(fmt.Errorf("ensure failed")) + + davBlobstore := client.NewWithStorageClient(fakeStorageClient) + err := davBlobstore.EnsureStorageExists() + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("ensure failed")) + }) + }) }) diff --git a/dav/client/clientfakes/fake_storage_client.go b/dav/client/clientfakes/fake_storage_client.go index 38b20b6..6d4cf55 100644 --- a/dav/client/clientfakes/fake_storage_client.go +++ b/dav/client/clientfakes/fake_storage_client.go @@ -10,6 +10,18 @@ import ( ) type FakeStorageClient struct { + CopyStub func(string, string) error + copyMutex sync.RWMutex + copyArgsForCall []struct { + arg1 string + arg2 string + } + copyReturns struct { + result1 error + } + copyReturnsOnCall map[int]struct { + result1 error + } DeleteStub func(string) error deleteMutex sync.RWMutex deleteArgsForCall []struct { @@ -21,6 +33,27 @@ type FakeStorageClient struct { deleteReturnsOnCall map[int]struct { result1 error } + DeleteRecursiveStub func(string) error + deleteRecursiveMutex sync.RWMutex + deleteRecursiveArgsForCall []struct { + arg1 string + } + deleteRecursiveReturns struct { + result1 error + } + deleteRecursiveReturnsOnCall map[int]struct { + result1 error + } + EnsureStorageExistsStub func() error + ensureStorageExistsMutex sync.RWMutex + ensureStorageExistsArgsForCall []struct { + } + ensureStorageExistsReturns struct { + result1 error + } + ensureStorageExistsReturnsOnCall map[int]struct { + result1 error + } ExistsStub func(string) (bool, error) existsMutex sync.RWMutex existsArgsForCall []struct { @@ -47,6 +80,30 @@ type FakeStorageClient struct { result1 io.ReadCloser result2 error } + ListStub func(string) ([]string, error) + listMutex sync.RWMutex + listArgsForCall []struct { + arg1 string + } + listReturns struct { + result1 []string + result2 error + } + listReturnsOnCall map[int]struct { + result1 []string + result2 error + } + PropertiesStub func(string) error + propertiesMutex sync.RWMutex + propertiesArgsForCall []struct { + arg1 string + } + propertiesReturns struct { + result1 error + } + propertiesReturnsOnCall map[int]struct { + result1 error + } PutStub func(string, io.ReadCloser, int64) error putMutex sync.RWMutex putArgsForCall []struct { @@ -79,6 +136,68 @@ type FakeStorageClient struct { invocationsMutex sync.RWMutex } +func (fake *FakeStorageClient) Copy(arg1 string, arg2 string) error { + fake.copyMutex.Lock() + ret, specificReturn := fake.copyReturnsOnCall[len(fake.copyArgsForCall)] + fake.copyArgsForCall = append(fake.copyArgsForCall, struct { + arg1 string + arg2 string + }{arg1, arg2}) + stub := fake.CopyStub + fakeReturns := fake.copyReturns + fake.recordInvocation("Copy", []interface{}{arg1, arg2}) + fake.copyMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorageClient) CopyCallCount() int { + fake.copyMutex.RLock() + defer fake.copyMutex.RUnlock() + return len(fake.copyArgsForCall) +} + +func (fake *FakeStorageClient) CopyCalls(stub func(string, string) error) { + fake.copyMutex.Lock() + defer fake.copyMutex.Unlock() + fake.CopyStub = stub +} + +func (fake *FakeStorageClient) CopyArgsForCall(i int) (string, string) { + fake.copyMutex.RLock() + defer fake.copyMutex.RUnlock() + argsForCall := fake.copyArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeStorageClient) CopyReturns(result1 error) { + fake.copyMutex.Lock() + defer fake.copyMutex.Unlock() + fake.CopyStub = nil + fake.copyReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorageClient) CopyReturnsOnCall(i int, result1 error) { + fake.copyMutex.Lock() + defer fake.copyMutex.Unlock() + fake.CopyStub = nil + if fake.copyReturnsOnCall == nil { + fake.copyReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.copyReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeStorageClient) Delete(arg1 string) error { fake.deleteMutex.Lock() ret, specificReturn := fake.deleteReturnsOnCall[len(fake.deleteArgsForCall)] @@ -140,6 +259,120 @@ func (fake *FakeStorageClient) DeleteReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeStorageClient) DeleteRecursive(arg1 string) error { + fake.deleteRecursiveMutex.Lock() + ret, specificReturn := fake.deleteRecursiveReturnsOnCall[len(fake.deleteRecursiveArgsForCall)] + fake.deleteRecursiveArgsForCall = append(fake.deleteRecursiveArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.DeleteRecursiveStub + fakeReturns := fake.deleteRecursiveReturns + fake.recordInvocation("DeleteRecursive", []interface{}{arg1}) + fake.deleteRecursiveMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorageClient) DeleteRecursiveCallCount() int { + fake.deleteRecursiveMutex.RLock() + defer fake.deleteRecursiveMutex.RUnlock() + return len(fake.deleteRecursiveArgsForCall) +} + +func (fake *FakeStorageClient) DeleteRecursiveCalls(stub func(string) error) { + fake.deleteRecursiveMutex.Lock() + defer fake.deleteRecursiveMutex.Unlock() + fake.DeleteRecursiveStub = stub +} + +func (fake *FakeStorageClient) DeleteRecursiveArgsForCall(i int) string { + fake.deleteRecursiveMutex.RLock() + defer fake.deleteRecursiveMutex.RUnlock() + argsForCall := fake.deleteRecursiveArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeStorageClient) DeleteRecursiveReturns(result1 error) { + fake.deleteRecursiveMutex.Lock() + defer fake.deleteRecursiveMutex.Unlock() + fake.DeleteRecursiveStub = nil + fake.deleteRecursiveReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorageClient) DeleteRecursiveReturnsOnCall(i int, result1 error) { + fake.deleteRecursiveMutex.Lock() + defer fake.deleteRecursiveMutex.Unlock() + fake.DeleteRecursiveStub = nil + if fake.deleteRecursiveReturnsOnCall == nil { + fake.deleteRecursiveReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.deleteRecursiveReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeStorageClient) EnsureStorageExists() error { + fake.ensureStorageExistsMutex.Lock() + ret, specificReturn := fake.ensureStorageExistsReturnsOnCall[len(fake.ensureStorageExistsArgsForCall)] + fake.ensureStorageExistsArgsForCall = append(fake.ensureStorageExistsArgsForCall, struct { + }{}) + stub := fake.EnsureStorageExistsStub + fakeReturns := fake.ensureStorageExistsReturns + fake.recordInvocation("EnsureStorageExists", []interface{}{}) + fake.ensureStorageExistsMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorageClient) EnsureStorageExistsCallCount() int { + fake.ensureStorageExistsMutex.RLock() + defer fake.ensureStorageExistsMutex.RUnlock() + return len(fake.ensureStorageExistsArgsForCall) +} + +func (fake *FakeStorageClient) EnsureStorageExistsCalls(stub func() error) { + fake.ensureStorageExistsMutex.Lock() + defer fake.ensureStorageExistsMutex.Unlock() + fake.EnsureStorageExistsStub = stub +} + +func (fake *FakeStorageClient) EnsureStorageExistsReturns(result1 error) { + fake.ensureStorageExistsMutex.Lock() + defer fake.ensureStorageExistsMutex.Unlock() + fake.EnsureStorageExistsStub = nil + fake.ensureStorageExistsReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorageClient) EnsureStorageExistsReturnsOnCall(i int, result1 error) { + fake.ensureStorageExistsMutex.Lock() + defer fake.ensureStorageExistsMutex.Unlock() + fake.EnsureStorageExistsStub = nil + if fake.ensureStorageExistsReturnsOnCall == nil { + fake.ensureStorageExistsReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.ensureStorageExistsReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeStorageClient) Exists(arg1 string) (bool, error) { fake.existsMutex.Lock() ret, specificReturn := fake.existsReturnsOnCall[len(fake.existsArgsForCall)] @@ -268,6 +501,131 @@ func (fake *FakeStorageClient) GetReturnsOnCall(i int, result1 io.ReadCloser, re }{result1, result2} } +func (fake *FakeStorageClient) List(arg1 string) ([]string, error) { + fake.listMutex.Lock() + ret, specificReturn := fake.listReturnsOnCall[len(fake.listArgsForCall)] + fake.listArgsForCall = append(fake.listArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.ListStub + fakeReturns := fake.listReturns + fake.recordInvocation("List", []interface{}{arg1}) + fake.listMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeStorageClient) ListCallCount() int { + fake.listMutex.RLock() + defer fake.listMutex.RUnlock() + return len(fake.listArgsForCall) +} + +func (fake *FakeStorageClient) ListCalls(stub func(string) ([]string, error)) { + fake.listMutex.Lock() + defer fake.listMutex.Unlock() + fake.ListStub = stub +} + +func (fake *FakeStorageClient) ListArgsForCall(i int) string { + fake.listMutex.RLock() + defer fake.listMutex.RUnlock() + argsForCall := fake.listArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeStorageClient) ListReturns(result1 []string, result2 error) { + fake.listMutex.Lock() + defer fake.listMutex.Unlock() + fake.ListStub = nil + fake.listReturns = struct { + result1 []string + result2 error + }{result1, result2} +} + +func (fake *FakeStorageClient) ListReturnsOnCall(i int, result1 []string, result2 error) { + fake.listMutex.Lock() + defer fake.listMutex.Unlock() + fake.ListStub = nil + if fake.listReturnsOnCall == nil { + fake.listReturnsOnCall = make(map[int]struct { + result1 []string + result2 error + }) + } + fake.listReturnsOnCall[i] = struct { + result1 []string + result2 error + }{result1, result2} +} + +func (fake *FakeStorageClient) Properties(arg1 string) error { + fake.propertiesMutex.Lock() + ret, specificReturn := fake.propertiesReturnsOnCall[len(fake.propertiesArgsForCall)] + fake.propertiesArgsForCall = append(fake.propertiesArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.PropertiesStub + fakeReturns := fake.propertiesReturns + fake.recordInvocation("Properties", []interface{}{arg1}) + fake.propertiesMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStorageClient) PropertiesCallCount() int { + fake.propertiesMutex.RLock() + defer fake.propertiesMutex.RUnlock() + return len(fake.propertiesArgsForCall) +} + +func (fake *FakeStorageClient) PropertiesCalls(stub func(string) error) { + fake.propertiesMutex.Lock() + defer fake.propertiesMutex.Unlock() + fake.PropertiesStub = stub +} + +func (fake *FakeStorageClient) PropertiesArgsForCall(i int) string { + fake.propertiesMutex.RLock() + defer fake.propertiesMutex.RUnlock() + argsForCall := fake.propertiesArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeStorageClient) PropertiesReturns(result1 error) { + fake.propertiesMutex.Lock() + defer fake.propertiesMutex.Unlock() + fake.PropertiesStub = nil + fake.propertiesReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeStorageClient) PropertiesReturnsOnCall(i int, result1 error) { + fake.propertiesMutex.Lock() + defer fake.propertiesMutex.Unlock() + fake.PropertiesStub = nil + if fake.propertiesReturnsOnCall == nil { + fake.propertiesReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.propertiesReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeStorageClient) Put(arg1 string, arg2 io.ReadCloser, arg3 int64) error { fake.putMutex.Lock() ret, specificReturn := fake.putReturnsOnCall[len(fake.putArgsForCall)] diff --git a/dav/client/helpers.go b/dav/client/helpers.go index c687d24..dc8e602 100644 --- a/dav/client/helpers.go +++ b/dav/client/helpers.go @@ -2,6 +2,8 @@ package client import ( "crypto/x509" + "fmt" + "strings" boshcrypto "github.com/cloudfoundry/bosh-utils/crypto" davconf "github.com/cloudfoundry/storage-cli/dav/config" @@ -19,3 +21,59 @@ func getCertPool(config davconf.Config) (*x509.CertPool, error) { return certPool, nil } + +// validateBlobID rejects blob IDs that could confuse path joining or enable +// path traversal: empty, leading/trailing slashes, double slashes, . or .. +// segments, and control characters. +func validateBlobID(blobID string) error { + if blobID == "" { + return fmt.Errorf("blob ID cannot be empty") + } + + if strings.HasPrefix(blobID, "/") || strings.HasSuffix(blobID, "/") { + return fmt.Errorf("blob ID cannot start or end with slash: %q", blobID) + } + + if strings.Contains(blobID, "//") { + return fmt.Errorf("blob ID cannot contain empty path segments (//): %q", blobID) + } + + for _, segment := range strings.Split(blobID, "/") { + if segment == "." || segment == ".." { + return fmt.Errorf("blob ID cannot contain path traversal segments (. or ..): %q", blobID) + } + } + + for _, r := range blobID { + if r < 32 || r == 127 { + return fmt.Errorf("blob ID cannot contain control characters: %q", blobID) + } + } + + return nil +} + +// validatePrefix is like validateBlobID but allows a trailing slash. +func validatePrefix(prefix string) error { + if strings.HasPrefix(prefix, "/") { + return fmt.Errorf("prefix cannot start with slash: %q", prefix) + } + + if strings.Contains(prefix, "//") { + return fmt.Errorf("prefix cannot contain empty path segments (//): %q", prefix) + } + + for _, segment := range strings.Split(strings.TrimSuffix(prefix, "/"), "/") { + if segment == "." || segment == ".." { + return fmt.Errorf("prefix cannot contain path traversal segments (. or ..): %q", prefix) + } + } + + for _, r := range prefix { + if r < 32 || r == 127 { + return fmt.Errorf("prefix cannot contain control characters: %q", prefix) + } + } + + return nil +} diff --git a/dav/client/storage_client.go b/dav/client/storage_client.go index 38ba8dc..f77d196 100644 --- a/dav/client/storage_client.go +++ b/dav/client/storage_client.go @@ -1,8 +1,11 @@ package client import ( + "encoding/json" + "encoding/xml" "fmt" "io" + "log/slog" "net/http" "net/url" "path" @@ -21,7 +24,70 @@ type StorageClient interface { Put(path string, content io.ReadCloser, contentLength int64) (err error) Exists(path string) (bool, error) Delete(path string) (err error) + DeleteRecursive(prefix string) error Sign(objectID, action string, duration time.Duration) (string, error) + Copy(srcBlob, dstBlob string) error + List(prefix string) ([]string, error) + Properties(path string) error + EnsureStorageExists() error +} + +type BlobProperties struct { + ETag string `json:"etag,omitempty"` + LastModified time.Time `json:"last_modified,omitempty"` + ContentLength *int64 `json:"content_length,omitempty"` +} + +// PROPFIND request body — sent as XML to ask the WebDAV server for the +// resourcetype of every child entry of a collection. +type propfindRequest struct { + XMLName xml.Name `xml:"D:propfind"` + DAVNS string `xml:"xmlns:D,attr"` + Prop propfindReqProp `xml:"D:prop"` +} + +type propfindReqProp struct { + ResourceType struct{} `xml:"D:resourcetype"` +} + +func newPropfindBody() (io.Reader, error) { + body := propfindRequest{DAVNS: "DAV:"} + out, err := xml.MarshalIndent(body, "", " ") + if err != nil { + return nil, fmt.Errorf("marshaling PROPFIND body: %w", err) + } + return strings.NewReader(xml.Header + string(out)), nil +} + +type multistatusResponse struct { + XMLName xml.Name `xml:"multistatus"` + Responses []davResponse `xml:"response"` +} + +type davResponse struct { + Href string `xml:"href"` + PropStats []davPropStat `xml:"propstat"` +} + +type davPropStat struct { + Prop davProp `xml:"prop"` +} + +type davProp struct { + ResourceType davResourceType `xml:"resourcetype"` +} + +type davResourceType struct { + Collection *struct{} `xml:"collection"` +} + +func (r davResponse) isCollection() bool { + for _, ps := range r.PropStats { + if ps.Prop.ResourceType.Collection != nil { + return true + } + } + return false } type storageClient struct { @@ -135,22 +201,29 @@ func (c *storageClient) Sign(blobID, action string, duration time.Duration) (str return signedURL, nil } -func (c *storageClient) createReq(method, blobID string, body io.Reader) (*http.Request, error) { +func (c *storageClient) buildBlobURL(blobID string) (string, error) { blobURL, err := url.Parse(c.config.Endpoint) if err != nil { - return nil, err + return "", err } newPath := path.Join(blobURL.Path, blobID) if !strings.HasPrefix(newPath, "/") { newPath = "/" + newPath } - blobURL.Path = newPath + return blobURL.String(), nil +} + +func (c *storageClient) createReq(method, blobID string, body io.Reader) (*http.Request, error) { + rawURL, err := c.buildBlobURL(blobID) + if err != nil { + return nil, err + } - req, err := http.NewRequest(method, blobURL.String(), body) + req, err := http.NewRequest(method, rawURL, body) if err != nil { - return req, err + return nil, err } if c.config.User != "" { @@ -166,3 +239,240 @@ func (c *storageClient) readAndTruncateBody(resp *http.Response) string { bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) //nolint:errcheck return string(bodyBytes) } + +func (c *storageClient) Copy(srcBlob, dstBlob string) error { + dstURL, err := c.buildBlobURL(dstBlob) + if err != nil { + return fmt.Errorf("building destination URL: %w", err) + } + + // PUT an empty file first so nginx (create_full_put_path on) creates any + // missing parent directories before COPY overwrites the placeholder. + putReq, err := c.createReq("PUT", dstBlob, http.NoBody) + if err != nil { + return fmt.Errorf("creating destination PUT request: %w", err) + } + putReq.ContentLength = 0 + + putResp, err := c.httpClient.Do(putReq) + if err != nil { + return fmt.Errorf("creating destination placeholder: %w", err) + } + defer putResp.Body.Close() //nolint:errcheck + + if putResp.StatusCode != http.StatusCreated && putResp.StatusCode != http.StatusNoContent && putResp.StatusCode != http.StatusOK { + return fmt.Errorf("creating destination placeholder %q: status %d, body: %s", + dstBlob, putResp.StatusCode, c.readAndTruncateBody(putResp)) + } + + copyReq, err := c.createReq("COPY", srcBlob, nil) + if err != nil { + return fmt.Errorf("creating COPY request: %w", err) + } + copyReq.Header.Set("Destination", dstURL) + copyReq.Header.Set("Overwrite", "T") + + copyResp, err := c.httpClient.Do(copyReq) + if err != nil { + return fmt.Errorf("performing COPY %q -> %q: %w", srcBlob, dstBlob, err) + } + defer copyResp.Body.Close() //nolint:errcheck + + // RFC 4918 §9.8: 201 Created (new) or 204 No Content (overwritten). + if copyResp.StatusCode == http.StatusCreated || copyResp.StatusCode == http.StatusNoContent { + return nil + } + + return fmt.Errorf("COPY %q -> %q: status %d, body: %s", + srcBlob, dstBlob, copyResp.StatusCode, c.readAndTruncateBody(copyResp)) +} + +func (c *storageClient) List(prefix string) ([]string, error) { + rootURL, err := url.Parse(c.config.Endpoint) + if err != nil { + return nil, fmt.Errorf("parsing endpoint URL: %w", err) + } + if !strings.HasPrefix(rootURL.Path, "/") { + rootURL.Path = "/" + rootURL.Path + } + + return c.listRecursive(rootURL.String(), rootURL.Path, prefix) +} + +func (c *storageClient) listRecursive(dirURL, endpointPath, prefix string) ([]string, error) { + body, err := newPropfindBody() + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PROPFIND", dirURL, body) + if err != nil { + return nil, fmt.Errorf("creating PROPFIND request: %w", err) + } + if c.config.User != "" { + req.SetBasicAuth(c.config.User, c.config.Password) + } + req.Header.Set("Depth", "1") + req.Header.Set("Content-Type", "application/xml") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("performing PROPFIND: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode == http.StatusNotFound { + return []string{}, nil + } + if resp.StatusCode != http.StatusMultiStatus && resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("PROPFIND %q: status %d, body: %s", + dirURL, resp.StatusCode, c.readAndTruncateBody(resp)) + } + + var multi multistatusResponse + if err := xml.NewDecoder(resp.Body).Decode(&multi); err != nil { + return nil, fmt.Errorf("decoding PROPFIND response: %w", err) + } + + parsedDirURL, err := url.Parse(dirURL) + if err != nil { + return nil, fmt.Errorf("parsing dirURL: %w", err) + } + currentPath := strings.TrimSuffix(parsedDirURL.Path, "/") + + var blobs []string + for _, response := range multi.Responses { + hrefURL, err := url.Parse(response.Href) + if err != nil { + slog.Warn("skipping unparseable href in PROPFIND response", "href", response.Href, "error", err) + continue + } + hrefPath := strings.TrimSuffix(hrefURL.Path, "/") + + if hrefPath == currentPath { + continue + } + + if response.isCollection() { + subURL := hrefURL.String() + if !hrefURL.IsAbs() { + subURL = parsedDirURL.ResolveReference(hrefURL).String() + } + sub, err := c.listRecursive(subURL, endpointPath, prefix) + if err != nil { + return nil, err + } + blobs = append(blobs, sub...) + continue + } + + blobID, err := blobIDFromHref(response.Href, endpointPath) + if err != nil { + slog.Warn("skipping href that could not be mapped to a blob ID", "href", response.Href, "error", err) + continue + } + if prefix == "" || strings.HasPrefix(blobID, prefix) { + blobs = append(blobs, blobID) + } + } + + return blobs, nil +} + +// blobIDFromHref extracts the blob ID from a WebDAV href +// Returns the path relative to the endpoint +func blobIDFromHref(href, endpointPath string) (string, error) { + if decoded, err := url.PathUnescape(href); err == nil { + href = decoded + } + + hrefURL, err := url.Parse(href) + if err != nil { + return "", fmt.Errorf("parsing href: %w", err) + } + + hrefPath := strings.TrimPrefix(hrefURL.Path, "/") + endpointClean := strings.Trim(endpointPath, "/") + if endpointClean != "" { + hrefPath = strings.TrimPrefix(hrefPath, endpointClean+"/") + } + + if hrefPath == "" { + return "", fmt.Errorf("href %q has no blob component after stripping endpoint %q", href, endpointPath) + } + return hrefPath, nil +} + +func (c *storageClient) DeleteRecursive(prefix string) error { + blobs, err := c.List(prefix) + if err != nil { + return fmt.Errorf("listing blobs under %q: %w", prefix, err) + } + + if len(blobs) == 0 { + slog.Warn("no blobs found for prefix, nothing deleted", "prefix", prefix) + return nil + } + + for _, blob := range blobs { + if err := c.Delete(blob); err != nil { + return fmt.Errorf("deleting %q: %w", blob, err) + } + } + return nil +} + +// Properties prints the blob's metadata (ETag, Last-Modified, Content-Length) +// as JSON to stdout. Returns nil with `{}` on 404 to mirror the behaviour of +// other backends (S3, Azure) for missing blobs. +func (c *storageClient) Properties(blobPath string) error { + req, err := c.createReq("HEAD", blobPath, nil) + if err != nil { + return fmt.Errorf("creating HEAD request for %q: %w", blobPath, err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("fetching properties of %q: %w", blobPath, err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode == http.StatusNotFound { + fmt.Println("{}") + return nil + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("fetching properties of %q: status %d", blobPath, resp.StatusCode) + } + + props := BlobProperties{} + if resp.ContentLength >= 0 { + props.ContentLength = &resp.ContentLength + } + if etag := resp.Header.Get("ETag"); etag != "" { + props.ETag = strings.Trim(etag, `"`) + } + if lm := resp.Header.Get("Last-Modified"); lm != "" { + if t, err := time.Parse(time.RFC1123, lm); err == nil { + props.LastModified = t + } else { + slog.Warn("could not parse Last-Modified header", "value", lm, "error", err) + } + } + + out, err := json.MarshalIndent(props, "", " ") + if err != nil { + return fmt.Errorf("marshaling properties: %w", err) + } + fmt.Println(string(out)) + return nil +} + +// EnsureStorageExists is a no-op for DAV. WebDAV has no "bucket" concept to +// provision: nginx auto-creates parent directories on first PUT (via +// `create_full_put_path on`), so there is nothing to do here. Matches the +// fog-based Ruby DavClient, whose ensure_bucket_exists is also empty. The +// method exists only to satisfy the StorageClient interface. +func (c *storageClient) EnsureStorageExists() error { + return nil +}