diff --git a/decompression_test.go b/decompression_test.go index cbd0cb8..2be61b3 100644 --- a/decompression_test.go +++ b/decompression_test.go @@ -132,3 +132,35 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) { assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } + +func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) { + tempDir, err := ioutil.TempDir("", "temp_tar_test") + if err != nil { + panic(err) + } + + archive, cleanup := createTempXzArchive() + + defer cleanup() + + file, err := os.OpenFile(archive, os.O_WRONLY, 0664) + if err != nil { + panic(err) + } + + if _, err := file.Seek(35, 0); err != nil { + panic(err) + } + + if _, err := file.WriteString("someJunk"); err != nil { + panic(err) + } + + if err := file.Close(); err != nil { + panic(err) + } + + err = decompressTarXz(defaultTarReader, archive, tempDir) + + assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt") +} diff --git a/remote_fetch.go b/remote_fetch.go index a33a1ee..ba0499e 100644 --- a/remote_fetch.go +++ b/remote_fetch.go @@ -3,6 +3,9 @@ package embeddedpostgres import ( "archive/zip" "bytes" + "crypto/sha256" + "encoding/hex" + "errors" "fmt" "io/ioutil" "log" @@ -19,7 +22,8 @@ type RemoteFetchStrategy func() error func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy { return func() error { operatingSystem, architecture, version := versionStrategy() - downloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar", + + jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar", remoteFetchHost, operatingSystem, architecture, @@ -28,32 +32,50 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS architecture, version) - resp, err := http.Get(downloadURL) + jarDownloadResponse, err := http.Get(jarDownloadURL) if err != nil { return fmt.Errorf("unable to connect to %s", remoteFetchHost) } - defer func() { - if err := resp.Body.Close(); err != nil { - log.Fatal(err) - } - }() + defer closeBody(jarDownloadResponse)() - if resp.StatusCode != http.StatusOK { + if jarDownloadResponse.StatusCode != http.StatusOK { return fmt.Errorf("no version found matching %s", version) } - return decompressResponse(resp, cacheLocator, downloadURL) + jarBodyBytes, err := ioutil.ReadAll(jarDownloadResponse.Body) + if err != nil { + return errorFetchingPostgres(err) + } + + shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL) + shaDownloadResponse, err := http.Get(shaDownloadURL) + + defer closeBody(shaDownloadResponse)() + + if err == nil && shaDownloadResponse.StatusCode == http.StatusOK { + if shaBodyBytes, err := ioutil.ReadAll(shaDownloadResponse.Body); err == nil { + jarChecksum := sha256.Sum256(jarBodyBytes) + if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) { + return errors.New("downloaded checksums do not match") + } + } + } + + return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL) } } -func decompressResponse(resp *http.Response, cacheLocator CacheLocator, downloadURL string) error { - bodyBytes, err := ioutil.ReadAll(resp.Body) - if err != nil { - return errorFetchingPostgres(err) +func closeBody(resp *http.Response) func() { + return func() { + if err := resp.Body.Close(); err != nil { + log.Fatal(err) + } } +} - zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), resp.ContentLength) +func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error { + zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), contentLength) if err != nil { return errorFetchingPostgres(err) } diff --git a/remote_fetch_test.go b/remote_fetch_test.go index 98588f0..57f7b0a 100644 --- a/remote_fetch_test.go +++ b/remote_fetch_test.go @@ -2,11 +2,14 @@ package embeddedpostgres import ( "archive/zip" + "crypto/sha256" + "encoding/hex" "io/ioutil" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -54,7 +57,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) { func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(http.StatusNotFound) + return + } })) defer server.Close() @@ -69,6 +75,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) { func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(404) + return + } + if _, err := w.Write([]byte("lolz")); err != nil { panic(err) } @@ -86,6 +97,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) { func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(http.StatusNotFound) + return + } + MyZipWriter := zip.NewWriter(w) if err := MyZipWriter.Close(); err != nil { @@ -114,6 +130,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(http.StatusNotFound) + return + } + bytes, err := ioutil.ReadFile(jarFile) if err != nil { panic(err) @@ -148,6 +169,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test cacheLocation := filepath.Join(fileBlockingExtractDirectory, "cache_file.jar") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(http.StatusNotFound) + return + } + bytes, err := ioutil.ReadFile(jarFile) if err != nil { panic(err) @@ -181,6 +207,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test } server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(http.StatusNotFound) + return + } + bytes, err := ioutil.ReadFile(jarFile) if err != nil { panic(err) @@ -202,6 +233,44 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test assert.Regexp(t, "^unable to extract postgres archive:.+$", err) } +func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) { + jarFile, cleanUp := createTempZipArchive() + defer cleanUp() + + cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bytes, err := ioutil.ReadFile(jarFile) + if err != nil { + panic(err) + } + + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(200) + if _, err := w.Write([]byte("literallyN3verGonnaWork")); err != nil { + panic(err) + } + + return + } + + if _, err := w.Write(bytes); err != nil { + panic(err) + } + })) + defer server.Close() + + remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2", + testVersionStrategy(), + func() (s string, b bool) { + return cacheLocation, false + }) + + err := remoteFetchStrategy() + + assert.EqualError(t, err, "downloaded checksums do not match") +} + func Test_defaultRemoteFetchStrategy(t *testing.T) { jarFile, cleanUp := createTempZipArchive() defer cleanUp() @@ -213,6 +282,17 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) { if err != nil { panic(err) } + + if strings.HasSuffix(r.RequestURI, ".sha256") { + w.WriteHeader(200) + contentHash := sha256.Sum256(bytes) + if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil { + panic(err) + } + + return + } + if _, err := w.Write(bytes); err != nil { panic(err) }