Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions pkg/cmd/attestation/api/client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package api

import (
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/cli/cli/v2/api"
ioconfig "github.com/cli/cli/v2/pkg/cmd/attestation/io"
)
Expand Down Expand Up @@ -69,6 +72,9 @@ func (c *LiveClient) GetTrustDomain() (string, error) {
return c.getTrustDomain(MetaPath)
}

// getAttestationRetryInterval is the constant delay used between retry
// attempts when fetching attestations or the trust domain. It is a variable
// rather than a constant so that tests can inject a different backoff interval.
var getAttestationRetryInterval = time.Millisecond * 200

func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*Attestation, error) {
c.logger.VerbosePrintf("Fetching attestations for artifact digest %s\n\n", digest)

Expand All @@ -86,15 +92,31 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At

var attestations []*Attestation
var resp AttestationsResponse
var err error
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)

// if no attestation or less than limit, then keep fetching
for url != "" && len(attestations) < limit {
url, err = c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)
err := backoff.Retry(func() error {
newURL, restErr := c.api.RESTWithNext(c.host, http.MethodGet, url, nil, &resp)

if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
}

url = newURL
attestations = append(attestations, resp.Attestations...)

return nil
}, backoff.WithMaxRetries(bo, 3))

// bail if RESTWithNext errored out
if err != nil {
return nil, err
}

attestations = append(attestations, resp.Attestations...)
}

if len(attestations) == 0 {
Expand All @@ -108,10 +130,34 @@ func (c *LiveClient) getAttestations(url, name, digest string, limit int) ([]*At
return attestations, nil
}

// shouldRetry reports whether err justifies another attempt. Only errors that
// unwrap to an api.HTTPError carrying a 5xx status code are considered
// retryable; every other error yields false.
func shouldRetry(err error) bool {
	var httpErr api.HTTPError
	if !errors.As(err, &httpErr) {
		return false
	}

	code := httpErr.StatusCode
	return code >= 500 && code <= 599
}

func (c *LiveClient) getTrustDomain(url string) (string, error) {
var resp MetaResponse

err := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
bo := backoff.NewConstantBackOff(getAttestationRetryInterval)
err := backoff.Retry(func() error {
restErr := c.api.REST(c.host, http.MethodGet, url, nil, &resp)
if restErr != nil {
if shouldRetry(restErr) {
return restErr
} else {
return backoff.Permanent(restErr)
}
}

return nil
}, backoff.WithMaxRetries(bo, 3))

if err != nil {
return "", err
}
Expand Down
59 changes: 59 additions & 0 deletions pkg/cmd/attestation/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,62 @@ func TestGetTrustDomain(t *testing.T) {
})

}

// TestGetAttestationsRetries checks that a paged attestation fetch recovers
// from intermittent 500 responses: the flaky handler fails on every other
// call, yet both lookups must still return the complete result set.
func TestGetAttestationsRetries(t *testing.T) {
	// Drop the retry delay so the test runs instantly, and restore it
	// afterwards so the override does not leak into other tests.
	originalInterval := getAttestationRetryInterval
	getAttestationRetryInterval = 0
	t.Cleanup(func() { getAttestationRetryInterval = originalInterval })

	fetcher := mockDataGenerator{
		NumAttestations: 5,
	}

	c := &LiveClient{
		api: mockAPIClient{
			OnRESTWithNext: fetcher.FlakyOnRESTSuccessWithNextPageHandler(),
		},
		logger: io.NewTestHandler(),
	}

	attestations, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
	require.NoError(t, err)

	// assert the error path was executed; because this is a paged
	// request, it should have errored twice
	fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 2)

	// but we still successfully got the right data
	require.Len(t, attestations, 10)
	bundle := attestations[0].Bundle
	require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())

	// same test as above, but for GetByOwnerAndDigest:
	attestations, err = c.GetByOwnerAndDigest(testOwner, testDigest, DefaultLimit)
	require.NoError(t, err)

	// because we haven't reset the mock, we have added 2 more failed requests
	fetcher.AssertNumberOfCalls(t, "FlakyOnRESTSuccessWithNextPage:error", 4)

	require.Len(t, attestations, 10)
	bundle = attestations[0].Bundle
	require.Equal(t, "application/vnd.dev.sigstore.bundle.v0.3+json", bundle.GetMediaType())
}

// TestGetAttestationsMaxRetries checks that retrying is bounded: when every
// request fails with a 500, the fetch gives up with an error after the
// initial attempt plus three retries (four calls in total).
func TestGetAttestationsMaxRetries(t *testing.T) {
	// Drop the retry delay so the test runs instantly, and restore it
	// afterwards so the override does not leak into other tests.
	originalInterval := getAttestationRetryInterval
	getAttestationRetryInterval = 0
	t.Cleanup(func() { getAttestationRetryInterval = originalInterval })

	fetcher := mockDataGenerator{
		NumAttestations: 5,
	}

	c := &LiveClient{
		api: mockAPIClient{
			OnRESTWithNext: fetcher.OnREST500ErrorHandler(),
		},
		logger: io.NewTestHandler(),
	}

	_, err := c.GetByRepoAndDigest(testRepo, testDigest, DefaultLimit)
	require.Error(t, err)

	fetcher.AssertNumberOfCalls(t, "OnREST500Error", 4)
}
45 changes: 40 additions & 5 deletions pkg/cmd/attestation/api/mock_apiClient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"fmt"
"io"
"strings"

cliAPI "github.com/cli/cli/v2/api"
ghAPI "github.com/cli/go-gh/v2/pkg/api"
"github.com/stretchr/testify/mock"
)

type mockAPIClient struct {
Expand All @@ -22,14 +26,15 @@ func (m mockAPIClient) REST(hostname, method, p string, body io.Reader, data int
}

// mockDataGenerator builds canned responses for the mock API client's
// handlers. The embedded mock.Mock lets handlers record named calls so tests
// can assert how many times a particular path was taken.
type mockDataGenerator struct {
mock.Mock
// NumAttestations is the number of attestations generated for each
// successful response.
NumAttestations int
}

func (m mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// OnRESTSuccess is a handler for a successful request; it delegates to
// OnRESTWithNextSuccessHelper with hasNext set to false.
func (m *mockDataGenerator) OnRESTSuccess(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
}

func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// if path doesn't contain after, it means first time hitting the mock server
// so return the first page and return the link header in the response
if !strings.Contains(p, "after") {
Expand All @@ -40,7 +45,37 @@ func (m mockDataGenerator) OnRESTSuccessWithNextPage(hostname, method, p string,
return m.OnRESTWithNextSuccessHelper(hostname, method, p, body, data, false)
}

func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
// FlakyOnRESTSuccessWithNextPageHandler returns a handler that alternates
// between failing and succeeding: the first call (and every second call after
// it) returns an HTTP 500 error and is recorded on the mock under
// "FlakyOnRESTSuccessWithNextPage:error"; the calls in between are forwarded
// to OnRESTSuccessWithNextPage.
func (m *mockDataGenerator) FlakyOnRESTSuccessWithNextPageHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
	const errorCall = "FlakyOnRESTSuccessWithNextPage:error"

	// register the expectation that doubles as the flake counter
	m.On(errorCall).Return()

	calls := 0
	return func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
		shouldFail := calls%2 == 0
		calls++

		if !shouldFail {
			return m.OnRESTSuccessWithNextPage(hostname, method, p, body, data)
		}

		m.MethodCalled(errorCall)
		return "", cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
	}
}

// OnREST500ErrorHandler registers the "OnREST500Error" expectation on the mock
// and returns a handler that records every invocation under that name and
// always fails with an HTTP 500 error.
func (m *mockDataGenerator) OnREST500ErrorHandler() func(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
	const callName = "OnREST500Error"
	m.On(callName).Return()

	handler := func(_, _, _ string, _ io.Reader, _ interface{}) (string, error) {
		m.MethodCalled(callName)

		serverErr := cliAPI.HTTPError{HTTPError: &ghAPI.HTTPError{StatusCode: 500}}
		return "", serverErr
	}
	return handler
}

func (m *mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p string, body io.Reader, data interface{}, hasNext bool) (string, error) {
atts := make([]*Attestation, m.NumAttestations)
for j := 0; j < m.NumAttestations; j++ {
att := makeTestAttestation()
Expand Down Expand Up @@ -70,7 +105,7 @@ func (m mockDataGenerator) OnRESTWithNextSuccessHelper(hostname, method, p strin
return "", nil
}

func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
func (m *mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
resp := AttestationsResponse{
Attestations: make([]*Attestation, 0),
}
Expand All @@ -89,7 +124,7 @@ func (m mockDataGenerator) OnRESTWithNextNoAttestations(hostname, method, p stri
return "", nil
}

func (m mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
// OnRESTWithNextError is a handler that always fails: it returns an empty
// string together with a plain (non-HTTP) error created by errors.New.
func (m *mockDataGenerator) OnRESTWithNextError(hostname, method, p string, body io.Reader, data interface{}) (string, error) {
return "", errors.New("failed to get attestations")
}

Expand Down