diff --git a/internal/container/controller.go b/internal/container/controller.go index 4838eecc..82e18ead 100644 --- a/internal/container/controller.go +++ b/internal/container/controller.go @@ -81,6 +81,10 @@ func (c Controller) ContainerRun( unitTestFailure bool, ) string { hostConfig := &container.HostConfig{ + // On Linux Docker Engine, host.docker.internal is not set by default + // (Docker Desktop sets it automatically). Add it so containers can + // reach host services, e.g. for bak-file restore from a local HTTP server. + ExtraHosts: []string{"host.docker.internal:host-gateway"}, PortBindings: network.PortMap{ network.MustParsePort("1433/tcp"): []network.PortBinding{ { @@ -250,23 +254,30 @@ func (c Controller) DownloadFile(id string, src string, destFolder string) { panic("Must pass in non-empty destFolder") } - cmd := []string{"mkdir", destFolder} - c.runCmdInContainer(id, cmd) - _, file := filepath.Split(src) + if file == "" { + panic("src URL has no filename: " + src) + } + dest := destFolder + "/" + file // not using filepath.Join here, this is in the *nix container. always / - // Wget the .bak file from the http src, and place it in /var/opt/sql/backup - cmd = []string{ - "wget", - "-O", - destFolder + "/" + file, // not using filepath.Join here, this is in the *nix container. always / - src, + cmd := []string{"mkdir", "-p", destFolder} + _, _, mkdirExit := c.runCmdInContainer(id, cmd) + if mkdirExit != 0 { + panic("mkdir failed for " + destFolder) } - c.runCmdInContainer(id, cmd) + cmd = []string{"wget", "-O", dest, src} + _, stderr, exitCode := c.runCmdInContainer(id, cmd) + if exitCode != 0 { + msg := "download failed: " + src + if len(stderr) > 0 { + msg += "\nwget output: " + string(stderr) + } + panic(msg) + } } -func (c Controller) runCmdInContainer(id string, cmd []string) ([]byte, []byte) { +func (c Controller) runCmdInContainer(id string, cmd []string) ([]byte, []byte, int) { trace("Running command in container: " + strings.Join(cmd, " ")) response, err := c.cli.ExecCreate( @@ -308,7 +319,17 @@ func (c Controller) runCmdInContainer(id string, cmd []string) ([]byte, []byte) trace("Stdout: " + string(stdout)) trace("Stderr: " + string(stderr)) - return stdout, stderr + // ExecInspect may rarely return Running:true after output is drained + // (moby/moby#42408). In practice the race window is negligible for + // short-lived commands like mkdir and wget. + inspect, err := c.cli.ExecInspect( + context.Background(), + response.ID, + client.ExecInspectOptions{}, + ) + checkErr(err) + + return stdout, stderr, inspect.ExitCode } // ContainerRunning returns true if the container with the given ID is running. diff --git a/internal/container/controller_test.go b/internal/container/controller_test.go index 347d7aec..24decfe5 100644 --- a/internal/container/controller_test.go +++ b/internal/container/controller_test.go @@ -5,10 +5,12 @@ package container import ( "fmt" - "github.com/stretchr/testify/assert" + "net" "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/assert" ) func TestController_ListTags(t *testing.T) { @@ -49,12 +51,23 @@ func TestController_EnsureImage(t *testing.T) { c.ContainerExists(id) c.ContainerFiles(id, "*.mdf") - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Bind to 0.0.0.0 so the container can reach the server via the + // Docker bridge network (host.docker.internal resolves to 172.17.0.1). + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("test")) })) + l, err := net.Listen("tcp4", "0.0.0.0:0") + checkErr(err) + ts.Listener = l + ts.Start() defer ts.Close() - c.DownloadFile(id, ts.URL, "test.txt") + // Build URL from listener port so it works regardless of whether + // the OS returns 127.0.0.1, localhost, or [::] in ts.URL. + _, tsPort, _ := net.SplitHostPort(ts.Listener.Addr().String()) + tsURL := fmt.Sprintf("http://host.docker.internal:%s", tsPort) + + c.DownloadFile(id, tsURL+"/test.bak", "/tmp") err = c.ContainerStop(id) checkErr(err) @@ -187,3 +200,10 @@ func TestController_DownloadFileNeg3(t *testing.T) { c.DownloadFile("not_blank", "not_blank", "") }) } + +func TestController_DownloadFileNoFilename(t *testing.T) { + c := NewController() + assert.Panics(t, func() { + c.DownloadFile("not_blank", "http://host:9999/", "/tmp") + }) +}