Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add test for real corrupted data and utilise sha256 where available #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions decompression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,35 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) {

assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
}

func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {
tempDir, err := ioutil.TempDir("", "temp_tar_test")
if err != nil {
panic(err)
}

archive, cleanup := createTempXzArchive()

defer cleanup()

file, err := os.OpenFile(archive, os.O_WRONLY, 0664)
if err != nil {
panic(err)
}

if _, err := file.Seek(35, 0); err != nil {
panic(err)
}

if _, err := file.WriteString("someJunk"); err != nil {
panic(err)
}

if err := file.Close(); err != nil {
panic(err)
}

err = decompressTarXz(defaultTarReader, archive, tempDir)

assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt")
}
50 changes: 36 additions & 14 deletions remote_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package embeddedpostgres
import (
"archive/zip"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io/ioutil"
"log"
Expand All @@ -19,7 +22,8 @@ type RemoteFetchStrategy func() error
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
return func() error {
operatingSystem, architecture, version := versionStrategy()
downloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",

jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
remoteFetchHost,
operatingSystem,
architecture,
Expand All @@ -28,32 +32,50 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS
architecture,
version)

resp, err := http.Get(downloadURL)
jarDownloadResponse, err := http.Get(jarDownloadURL)
if err != nil {
return fmt.Errorf("unable to connect to %s", remoteFetchHost)
}

defer func() {
if err := resp.Body.Close(); err != nil {
log.Fatal(err)
}
}()
defer closeBody(jarDownloadResponse)()

if resp.StatusCode != http.StatusOK {
if jarDownloadResponse.StatusCode != http.StatusOK {
return fmt.Errorf("no version found matching %s", version)
}

return decompressResponse(resp, cacheLocator, downloadURL)
jarBodyBytes, err := ioutil.ReadAll(jarDownloadResponse.Body)
if err != nil {
return errorFetchingPostgres(err)
}

shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
shaDownloadResponse, err := http.Get(shaDownloadURL)

defer closeBody(shaDownloadResponse)()

if err == nil && shaDownloadResponse.StatusCode == http.StatusOK {
if shaBodyBytes, err := ioutil.ReadAll(shaDownloadResponse.Body); err == nil {
jarChecksum := sha256.Sum256(jarBodyBytes)
if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) {
return errors.New("downloaded checksums do not match")
}
}
}

return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
}
}

func decompressResponse(resp *http.Response, cacheLocator CacheLocator, downloadURL string) error {
bodyBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errorFetchingPostgres(err)
func closeBody(resp *http.Response) func() {
return func() {
if err := resp.Body.Close(); err != nil {
log.Fatal(err)
}
}
}

zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), resp.ContentLength)
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), contentLength)
if err != nil {
return errorFetchingPostgres(err)
}
Expand Down
82 changes: 81 additions & 1 deletion remote_fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package embeddedpostgres

import (
"archive/zip"
"crypto/sha256"
"encoding/hex"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -54,7 +57,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) {

func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}
}))
defer server.Close()

Expand All @@ -69,6 +75,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {

func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(404)
return
}

if _, err := w.Write([]byte("lolz")); err != nil {
panic(err)
}
Expand All @@ -86,6 +97,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {

func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

MyZipWriter := zip.NewWriter(w)

if err := MyZipWriter.Close(); err != nil {
Expand Down Expand Up @@ -114,6 +130,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
Expand Down Expand Up @@ -148,6 +169,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test
cacheLocation := filepath.Join(fileBlockingExtractDirectory, "cache_file.jar")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
Expand Down Expand Up @@ -181,6 +207,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
Expand All @@ -202,6 +233,44 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
}

func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) {
jarFile, cleanUp := createTempZipArchive()
defer cleanUp()

cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
}

if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(200)
if _, err := w.Write([]byte("literallyN3verGonnaWork")); err != nil {
panic(err)
}

return
}

if _, err := w.Write(bytes); err != nil {
panic(err)
}
}))
defer server.Close()

remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
testVersionStrategy(),
func() (s string, b bool) {
return cacheLocation, false
})

err := remoteFetchStrategy()

assert.EqualError(t, err, "downloaded checksums do not match")
}

func Test_defaultRemoteFetchStrategy(t *testing.T) {
jarFile, cleanUp := createTempZipArchive()
defer cleanUp()
Expand All @@ -213,6 +282,17 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) {
if err != nil {
panic(err)
}

if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(200)
contentHash := sha256.Sum256(bytes)
if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil {
panic(err)
}

return
}

if _, err := w.Write(bytes); err != nil {
panic(err)
}
Expand Down