Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1d2eeea

Browse files
Add test for real corrupted data and utilise sha256 where available (fergusstrange#73)
* Add test for real corrupted data. * Try adding integrity check on downloaded files. * Lint * Add coverage for checking sha * Up coverage * Check via bytes rather than string and tidy code
1 parent 9f87ef1 commit 1d2eeea

File tree

3 files changed

+149
-15
lines changed

3 files changed

+149
-15
lines changed

decompression_test.go

+32
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,35 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) {
132132

133133
assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
134134
}
135+
136+
func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {
137+
tempDir, err := ioutil.TempDir("", "temp_tar_test")
138+
if err != nil {
139+
panic(err)
140+
}
141+
142+
archive, cleanup := createTempXzArchive()
143+
144+
defer cleanup()
145+
146+
file, err := os.OpenFile(archive, os.O_WRONLY, 0664)
147+
if err != nil {
148+
panic(err)
149+
}
150+
151+
if _, err := file.Seek(35, 0); err != nil {
152+
panic(err)
153+
}
154+
155+
if _, err := file.WriteString("someJunk"); err != nil {
156+
panic(err)
157+
}
158+
159+
if err := file.Close(); err != nil {
160+
panic(err)
161+
}
162+
163+
err = decompressTarXz(defaultTarReader, archive, tempDir)
164+
165+
assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt")
166+
}

remote_fetch.go

+36-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ package embeddedpostgres
33
import (
44
"archive/zip"
55
"bytes"
6+
"crypto/sha256"
7+
"encoding/hex"
8+
"errors"
69
"fmt"
710
"io/ioutil"
811
"log"
@@ -19,7 +22,8 @@ type RemoteFetchStrategy func() error
1922
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
2023
return func() error {
2124
operatingSystem, architecture, version := versionStrategy()
22-
downloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
25+
26+
jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
2327
remoteFetchHost,
2428
operatingSystem,
2529
architecture,
@@ -28,32 +32,50 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS
2832
architecture,
2933
version)
3034

31-
resp, err := http.Get(downloadURL)
35+
jarDownloadResponse, err := http.Get(jarDownloadURL)
3236
if err != nil {
3337
return fmt.Errorf("unable to connect to %s", remoteFetchHost)
3438
}
3539

36-
defer func() {
37-
if err := resp.Body.Close(); err != nil {
38-
log.Fatal(err)
39-
}
40-
}()
40+
defer closeBody(jarDownloadResponse)()
4141

42-
if resp.StatusCode != http.StatusOK {
42+
if jarDownloadResponse.StatusCode != http.StatusOK {
4343
return fmt.Errorf("no version found matching %s", version)
4444
}
4545

46-
return decompressResponse(resp, cacheLocator, downloadURL)
46+
jarBodyBytes, err := ioutil.ReadAll(jarDownloadResponse.Body)
47+
if err != nil {
48+
return errorFetchingPostgres(err)
49+
}
50+
51+
shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
52+
shaDownloadResponse, err := http.Get(shaDownloadURL)
53+
54+
defer closeBody(shaDownloadResponse)()
55+
56+
if err == nil && shaDownloadResponse.StatusCode == http.StatusOK {
57+
if shaBodyBytes, err := ioutil.ReadAll(shaDownloadResponse.Body); err == nil {
58+
jarChecksum := sha256.Sum256(jarBodyBytes)
59+
if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) {
60+
return errors.New("downloaded checksums do not match")
61+
}
62+
}
63+
}
64+
65+
return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
4766
}
4867
}
4968

50-
func decompressResponse(resp *http.Response, cacheLocator CacheLocator, downloadURL string) error {
51-
bodyBytes, err := ioutil.ReadAll(resp.Body)
52-
if err != nil {
53-
return errorFetchingPostgres(err)
69+
func closeBody(resp *http.Response) func() {
70+
return func() {
71+
if err := resp.Body.Close(); err != nil {
72+
log.Fatal(err)
73+
}
5474
}
75+
}
5576

56-
zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), resp.ContentLength)
77+
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
78+
zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), contentLength)
5779
if err != nil {
5880
return errorFetchingPostgres(err)
5981
}

remote_fetch_test.go

+81-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package embeddedpostgres
22

33
import (
44
"archive/zip"
5+
"crypto/sha256"
6+
"encoding/hex"
57
"io/ioutil"
68
"net/http"
79
"net/http/httptest"
810
"os"
911
"path/filepath"
12+
"strings"
1013
"testing"
1114

1215
"github.com/stretchr/testify/assert"
@@ -54,7 +57,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) {
5457

5558
func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {
5659
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57-
60+
if strings.HasSuffix(r.RequestURI, ".sha256") {
61+
w.WriteHeader(http.StatusNotFound)
62+
return
63+
}
5864
}))
5965
defer server.Close()
6066

@@ -69,6 +75,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {
6975

7076
func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {
7177
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78+
if strings.HasSuffix(r.RequestURI, ".sha256") {
79+
w.WriteHeader(404)
80+
return
81+
}
82+
7283
if _, err := w.Write([]byte("lolz")); err != nil {
7384
panic(err)
7485
}
@@ -86,6 +97,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {
8697

8798
func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) {
8899
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
100+
if strings.HasSuffix(r.RequestURI, ".sha256") {
101+
w.WriteHeader(http.StatusNotFound)
102+
return
103+
}
104+
89105
MyZipWriter := zip.NewWriter(w)
90106

91107
if err := MyZipWriter.Close(); err != nil {
@@ -114,6 +130,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing
114130
}
115131

116132
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133+
if strings.HasSuffix(r.RequestURI, ".sha256") {
134+
w.WriteHeader(http.StatusNotFound)
135+
return
136+
}
137+
117138
bytes, err := ioutil.ReadFile(jarFile)
118139
if err != nil {
119140
panic(err)
@@ -148,6 +169,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test
148169
cacheLocation := filepath.Join(fileBlockingExtractDirectory, "cache_file.jar")
149170

150171
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
172+
if strings.HasSuffix(r.RequestURI, ".sha256") {
173+
w.WriteHeader(http.StatusNotFound)
174+
return
175+
}
176+
151177
bytes, err := ioutil.ReadFile(jarFile)
152178
if err != nil {
153179
panic(err)
@@ -181,6 +207,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
181207
}
182208

183209
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
210+
if strings.HasSuffix(r.RequestURI, ".sha256") {
211+
w.WriteHeader(http.StatusNotFound)
212+
return
213+
}
214+
184215
bytes, err := ioutil.ReadFile(jarFile)
185216
if err != nil {
186217
panic(err)
@@ -202,6 +233,44 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
202233
assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
203234
}
204235

236+
func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) {
237+
jarFile, cleanUp := createTempZipArchive()
238+
defer cleanUp()
239+
240+
cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar")
241+
242+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
243+
bytes, err := ioutil.ReadFile(jarFile)
244+
if err != nil {
245+
panic(err)
246+
}
247+
248+
if strings.HasSuffix(r.RequestURI, ".sha256") {
249+
w.WriteHeader(200)
250+
if _, err := w.Write([]byte("literallyN3verGonnaWork")); err != nil {
251+
panic(err)
252+
}
253+
254+
return
255+
}
256+
257+
if _, err := w.Write(bytes); err != nil {
258+
panic(err)
259+
}
260+
}))
261+
defer server.Close()
262+
263+
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
264+
testVersionStrategy(),
265+
func() (s string, b bool) {
266+
return cacheLocation, false
267+
})
268+
269+
err := remoteFetchStrategy()
270+
271+
assert.EqualError(t, err, "downloaded checksums do not match")
272+
}
273+
205274
func Test_defaultRemoteFetchStrategy(t *testing.T) {
206275
jarFile, cleanUp := createTempZipArchive()
207276
defer cleanUp()
@@ -213,6 +282,17 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) {
213282
if err != nil {
214283
panic(err)
215284
}
285+
286+
if strings.HasSuffix(r.RequestURI, ".sha256") {
287+
w.WriteHeader(200)
288+
contentHash := sha256.Sum256(bytes)
289+
if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil {
290+
panic(err)
291+
}
292+
293+
return
294+
}
295+
216296
if _, err := w.Write(bytes); err != nil {
217297
panic(err)
218298
}

0 commit comments

Comments
 (0)