diff --git a/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar b/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar index 9a0494f4bb6..637422a5aef 100644 --- a/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar +++ b/acceptance/testdata/pr/pr-checkout-with-url-from-fork.txtar @@ -12,6 +12,7 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork +sleep 5 # Defer fork cleanup defer gh repo delete --yes ${ORG}/${REPO}-fork diff --git a/acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar b/acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar new file mode 100644 index 00000000000..52579b50190 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-guesses-remote-from-sha.txtar @@ -0,0 +1,46 @@ +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch to commit +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Upstream Commit' +exec git push upstream feature-branch + +# Prepare an additional commit +exec git commit --allow-empty -m 'Fork Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Check the PR is indeed created +exec gh pr view ${USER}:feature-branch --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-no-local-repo.txtar b/acceptance/testdata/pr/pr-create-no-local-repo.txtar new file mode 100644 index 00000000000..cb42d99f829 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-no-local-repo.txtar @@ -0,0 +1,27 @@ +# Use gh as a credential helper +exec gh auth setup-git + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup +defer gh repo delete --yes ${ORG}/${REPO} + +# Clone the repo +exec gh repo clone ${ORG}/${REPO} + +# Prepare a branch to PR +cd ${REPO} +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push -u origin feature-branch + +# Leave the repo so there's no local repo +cd ${WORK} + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' --repo ${ORG}/${REPO} --head feature-branch +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 \ No newline at end of file diff --git a/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar b/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar new file mode 100644 index 00000000000..e0d0c099cd7 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-respects-branch-pushremote.txtar @@ -0,0 +1,49 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch +exec git branch --set-upstream-to upstream/main +exec git config branch.feature-branch.pushRemote origin +exec git config unset remote.upstream.gh-resolved +exec git commit --allow-empty -m 'Empty Commit' +exec git push + +# Create the PR spanning upstream and fork repositories +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-respects-push-destination.txtar b/acceptance/testdata/pr/pr-create-respects-push-destination.txtar new file mode 100644 index 00000000000..51708405d8f --- /dev/null +++ b/acceptance/testdata/pr/pr-create-respects-push-destination.txtar @@ -0,0 +1,53 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Configure default push behavior so local and remote branches will be the same +exec git config push.default current + +# Prepare a branch where changes are pulled from the default branch instead of remote branch of same name +exec git checkout -b feature-branch +exec git branch --set-upstream-to origin/main +exec git rev-parse --abbrev-ref feature-branch@{upstream} +stdout origin/main +exec git config unset remote.upstream.gh-resolved +exec git commit --allow-empty -m 'Empty Commit' +exec git push + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar b/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar new file mode 100644 index 00000000000..ff92f1e2d49 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-respects-remote-pushdefault.txtar @@ -0,0 +1,49 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch +exec git branch --set-upstream-to upstream/main +exec git config remote.pushDefault origin +exec git config unset remote.upstream.gh-resolved +exec git commit --allow-empty -m 'Empty Commit' +exec git push + +# Create the PR spanning upstream and fork repositories +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-respects-simple-pushdefault.txtar b/acceptance/testdata/pr/pr-create-respects-simple-pushdefault.txtar new file mode 100644 index 00000000000..63d3ae2b41e --- /dev/null +++ b/acceptance/testdata/pr/pr-create-respects-simple-pushdefault.txtar @@ -0,0 +1,34 @@ +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + +# Use gh as a credential helper +exec gh auth setup-git + +# Create a repository with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of repo +defer gh repo delete --yes ${ORG}/${REPO} +exec gh repo view ${ORG}/${REPO} --json id --jq '.id' +stdout2env REPO_ID + +# Clone the repo +exec gh repo clone ${ORG}/${REPO} +cd ${REPO} + +# Configure default push behavior so local and remote branches have to be the same +exec git config push.default simple + +# Prepare a branch where changes are pulled from the default branch instead of remote branch of same name +exec git checkout -b feature-branch +exec git branch --set-upstream-to origin/main +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch + +# Create the PR +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr view --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${REPO_ID}","name":"${REPO}"},"isCrossRepository":false} diff --git a/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar b/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar new file mode 100644 index 00000000000..a59171d5899 --- /dev/null +++ b/acceptance/testdata/pr/pr-create-respects-user-colon-branch-syntax.txtar @@ -0,0 +1,47 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the fork +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch +exec git branch --set-upstream-to upstream/main +exec git commit --allow-empty -m 'Empty Commit' +exec git push origin feature-branch + +# Create the PR spanning upstream and fork repositories +exec gh pr create --title 'Feature Title' --body 'Feature Body' --head ${USER}:feature-branch +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr view ${USER}:feature-branch --json headRefName,headRepository,baseRefName,isCrossRepository +stdout {"baseRefName":"main","headRefName":"feature-branch","headRepository":{"id":"${FORK_ID}","name":"${FORK}"},"isCrossRepository":true} diff --git a/acceptance/testdata/pr/pr-create-without-upstream-config.txtar b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar index 00f3535a775..e5a40af72a1 100644 --- a/acceptance/testdata/pr/pr-create-without-upstream-config.txtar +++ b/acceptance/testdata/pr/pr-create-without-upstream-config.txtar @@ -1,20 +1,22 @@ # This test is the same as pr-create-basic, except that the git push doesn't include the -u argument # This causes a git config read to fail during gh pr create, but it should not be fatal +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} + # Use gh as a credential helper exec gh auth setup-git # Create a repository with a file so it has a default branch -exec gh repo create $ORG/$SCRIPT_NAME-$RANDOM_STRING --add-readme --private +exec gh repo create ${ORG}/${REPO} --add-readme --private # Defer repo cleanup -defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING +defer gh repo delete --yes ${ORG}/${REPO} # Clone the repo -exec gh repo clone $ORG/$SCRIPT_NAME-$RANDOM_STRING +exec gh repo clone ${ORG}/${REPO} # Prepare a branch to PR -cd $SCRIPT_NAME-$RANDOM_STRING +cd ${REPO} exec git checkout -b feature-branch exec git commit --allow-empty -m 'Empty Commit' exec git push origin feature-branch diff --git a/acceptance/testdata/pr/pr-status-respects-cross-org.txtar b/acceptance/testdata/pr/pr-status-respects-cross-org.txtar new file mode 100644 index 00000000000..4505be92352 --- /dev/null +++ b/acceptance/testdata/pr/pr-status-respects-cross-org.txtar @@ -0,0 +1,46 @@ +skip 'it creates a fork owned by the user running the test' + +# Setup environment variables used for testscript +env REPO=${SCRIPT_NAME}-${RANDOM_STRING} +env FORK=${REPO}-fork + +# Use gh as a credential helper +exec gh auth setup-git + +# Get the current username for the fork owner +exec gh api user --jq .login +stdout2env USER + +# Create a repository to act as upstream with a file so it has a default branch +exec gh repo create ${ORG}/${REPO} --add-readme --private + +# Defer repo cleanup of upstream +defer gh repo delete --yes ${ORG}/${REPO} + +# Create a user fork of repository. This will be owned by USER. +exec gh repo fork ${ORG}/${REPO} --fork-name ${FORK} +sleep 5 + +# Defer repo cleanup of fork +defer gh repo delete --yes ${USER}/${FORK} + +# Retrieve fork repository information +exec gh repo view ${USER}/${FORK} --json id --jq '.id' +stdout2env FORK_ID + +# Clone the repo +exec gh repo clone ${USER}/${FORK} +cd ${FORK} + +# Prepare a branch where changes are pulled from the upstream default branch but pushed to fork +exec git checkout -b feature-branch +exec git commit --allow-empty -m 'Empty Commit' +exec git push -u origin feature-branch + +# Create the PR spanning upstream and fork repositories +exec gh pr create --title 'Feature Title' --body 'Feature Body' +stdout https://${GH_HOST}/${ORG}/${REPO}/pull/1 + +# Assert that the PR was created with the correct head repository and refs +exec gh pr status +! stdout 'There is no pull request associated with' diff --git a/acceptance/testdata/pr/pr-view-same-org-fork.txtar b/acceptance/testdata/pr/pr-view-same-org-fork.txtar index ca58918a911..eed524dec05 100644 --- a/acceptance/testdata/pr/pr-view-same-org-fork.txtar +++ b/acceptance/testdata/pr/pr-view-same-org-fork.txtar @@ -15,10 +15,11 @@ stdout2env REPO_ID # Create a fork in the same org exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${ORG}/${FORK} -sleep 1 + exec gh repo view ${ORG}/${FORK} --json id --jq '.id' stdout2env FORK_ID diff --git a/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar b/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar index ef80cd8babf..4e1e5e64ac7 100644 --- a/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar +++ b/acceptance/testdata/pr/pr-view-status-respects-branch-pushremote.txtar @@ -15,10 +15,11 @@ stdout2env REPO_ID # Create a user fork of repository as opposed to private organization fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${ORG}/${FORK} -sleep 5 + exec gh repo view ${ORG}/${FORK} --json id --jq '.id' stdout2env FORK_ID @@ -27,7 +28,8 @@ exec gh repo clone ${ORG}/${FORK} cd ${FORK} # Prepare a branch where changes are pulled from the upstream default branch but pushed to fork -exec git checkout -b feature-branch upstream/main +exec git checkout -b feature-branch +exec git branch --set-upstream-to upstream/main exec git config branch.feature-branch.pushRemote origin exec git commit --allow-empty -m 'Empty Commit' exec git push diff --git a/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar b/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar index 8bfac28376a..6c0743a6f14 100644 --- a/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar +++ b/acceptance/testdata/pr/pr-view-status-respects-remote-pushdefault.txtar @@ -15,10 +15,11 @@ stdout2env REPO_ID # Create a user fork of repository as opposed to private organization fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${FORK} +sleep 5 # Defer repo cleanup of fork defer gh repo delete --yes ${ORG}/${FORK} -sleep 5 + exec gh repo view ${ORG}/${FORK} --json id --jq '.id' stdout2env FORK_ID @@ -27,7 +28,8 @@ exec gh repo clone ${ORG}/${FORK} cd ${FORK} # Prepare a branch where changes are pulled from the upstream default branch but pushed to fork -exec git checkout -b feature-branch upstream/main +exec git checkout -b feature-branch +exec git branch --set-upstream-to upstream/main exec git config remote.pushDefault origin exec git commit --allow-empty -m 'Empty Commit' exec git push diff --git a/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar b/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar index 114f401ecb6..b9621ea72cc 100644 --- a/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar +++ b/acceptance/testdata/pr/pr-view-status-respects-simple-pushdefault.txtar @@ -18,7 +18,8 @@ cd ${REPO} exec git config push.default simple # Prepare a branch where changes are pulled from the default branch instead of remote branch of same name -exec git checkout -b feature-branch origin/main +exec git checkout -b feature-branch +exec git branch --set-upstream-to origin/main # Create the PR exec git commit --allow-empty -m 'Empty Commit' diff --git a/acceptance/testdata/repo/repo-fork-sync.txtar b/acceptance/testdata/repo/repo-fork-sync.txtar index 6ed7b94e1a7..04c4c584555 100644 --- a/acceptance/testdata/repo/repo-fork-sync.txtar +++ b/acceptance/testdata/repo/repo-fork-sync.txtar @@ -9,13 +9,11 @@ defer gh repo delete --yes $ORG/$SCRIPT_NAME-$RANDOM_STRING # Fork and clone the repo exec gh repo fork $ORG/$SCRIPT_NAME-$RANDOM_STRING --org $ORG --fork-name $SCRIPT_NAME-$RANDOM_STRING-fork --clone +sleep 5 # Defer fork cleanup defer gh repo delete $ORG/$SCRIPT_NAME-$RANDOM_STRING-fork --yes -# Sleep so that the BE has time to sync -sleep 5 - # Check that the repo was forked exec gh repo view $ORG/$SCRIPT_NAME-$RANDOM_STRING-fork --json='isFork' --jq='.isFork' stdout 'true' diff --git a/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar b/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar index 02dec06a00d..f3fa4a47a0a 100644 --- a/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar +++ b/acceptance/testdata/secret/secret-require-remote-disambiguation.txtar @@ -12,13 +12,11 @@ defer gh repo delete --yes ${ORG}/${REPO} # Create a fork exec gh repo fork ${ORG}/${REPO} --org ${ORG} --fork-name ${REPO}-fork +sleep 5 # Defer fork cleanup defer gh repo delete --yes ${ORG}/${REPO}-fork -# Sleep to allow the fork to be created before cloning -sleep 2 - # Clone and move into the fork repo exec gh repo clone ${ORG}/${REPO}-fork cd ${REPO}-fork diff --git a/api/queries_repo.go b/api/queries_repo.go index 53e6d879a47..27e21eb32ac 100644 --- a/api/queries_repo.go +++ b/api/queries_repo.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghinstance" "golang.org/x/sync/errgroup" @@ -782,35 +783,54 @@ func (m *RepoMetadataResult) projectV2TitleToID(projectTitle string) (string, bo return "", false } -func ProjectsToPaths(projects []RepoProject, projectsV2 []ProjectV2, names []string) ([]string, error) { - var paths []string - for _, projectName := range names { - found := false - for _, p := range projects { - if strings.EqualFold(projectName, p.Name) { - // format of ResourcePath: /OWNER/REPO/projects/PROJECT_NUMBER or /orgs/ORG/projects/PROJECT_NUMBER or /users/USER/projects/PROJECT_NUBER - // required format of path: OWNER/REPO/PROJECT_NUMBER or ORG/PROJECT_NUMBER or USER/PROJECT_NUMBER - var path string - pathParts := strings.Split(p.ResourcePath, "/") - if pathParts[1] == "orgs" || pathParts[1] == "users" { - path = fmt.Sprintf("%s/%s", pathParts[2], pathParts[4]) - } else { - path = fmt.Sprintf("%s/%s/%s", pathParts[1], pathParts[2], pathParts[4]) +func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []string, projectsV1Support gh.ProjectsV1Support) ([]string, error) { + paths := make([]string, 0, len(projectNames)) + matchedPaths := map[string]struct{}{} + + // TODO: ProjectsV1Cleanup + // At this point, we only know the names that the user has provided, so we can't push this conditional up the stack. + // First we'll try to match against v1 projects, if supported + if projectsV1Support == gh.ProjectsV1Supported { + v1Projects, err := v1Projects(client, repo) + if err != nil { + return nil, err + } + + for _, projectName := range projectNames { + for _, p := range v1Projects { + if strings.EqualFold(projectName, p.Name) { + pathParts := strings.Split(p.ResourcePath, "/") + var path string + if pathParts[1] == "orgs" || pathParts[1] == "users" { + path = fmt.Sprintf("%s/%s", pathParts[2], pathParts[4]) + } else { + path = fmt.Sprintf("%s/%s/%s", pathParts[1], pathParts[2], pathParts[4]) + } + paths = append(paths, path) + matchedPaths[projectName] = struct{}{} + break } - paths = append(paths, path) - found = true - break } } - if found { + } + + // Then we'll try to match against v2 projects + v2Projects, err := v2Projects(client, repo) + if err != nil { + return nil, err + } + + for _, projectName := range projectNames { + // If we already found a v1 project with this name, skip it + if _, ok := matchedPaths[projectName]; ok { continue } - for _, p := range projectsV2 { + + found := false + for _, p := range v2Projects { if strings.EqualFold(projectName, p.Title) { - // format of ResourcePath: /OWNER/REPO/projects/PROJECT_NUMBER or /orgs/ORG/projects/PROJECT_NUMBER or /users/USER/projects/PROJECT_NUBER - // required format of path: OWNER/REPO/PROJECT_NUMBER or ORG/PROJECT_NUMBER or USER/PROJECT_NUMBER - var path string pathParts := strings.Split(p.ResourcePath, "/") + var path string if pathParts[1] == "orgs" || pathParts[1] == "users" { path = fmt.Sprintf("%s/%s", pathParts[2], pathParts[4]) } else { @@ -821,10 +841,12 @@ func ProjectsToPaths(projects []RepoProject, projectsV2 []ProjectV2, names []str break } } + if !found { return nil, fmt.Errorf("'%s' not found", projectName) } } + return paths, nil } @@ -863,7 +885,8 @@ type RepoMetadataInput struct { Assignees bool Reviewers bool Labels bool - Projects bool + ProjectsV1 bool + ProjectsV2 bool Milestones bool } @@ -882,6 +905,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput return err }) } + if input.Reviewers { g.Go(func() error { teams, err := OrganizationTeams(client, repo) @@ -894,6 +918,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput return nil }) } + if input.Reviewers { g.Go(func() error { login, err := CurrentLoginName(client, repo.RepoHost()) @@ -904,6 +929,7 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput return err }) } + if input.Labels { g.Go(func() error { labels, err := RepoLabels(client, repo) @@ -914,13 +940,23 @@ func RepoMetadata(client *Client, repo ghrepo.Interface, input RepoMetadataInput return err }) } - if input.Projects { + + if input.ProjectsV1 { + g.Go(func() error { + var err error + result.Projects, err = v1Projects(client, repo) + return err + }) + } + + if input.ProjectsV2 { g.Go(func() error { var err error - result.Projects, result.ProjectsV2, err = relevantProjects(client, repo) + result.ProjectsV2, err = v2Projects(client, repo) return err }) } + if input.Milestones { g.Go(func() error { milestones, err := RepoMilestones(client, repo, "open") @@ -943,7 +979,8 @@ type RepoResolveInput struct { Assignees []string Reviewers []string Labels []string - Projects []string + ProjectsV1 bool + ProjectsV2 bool Milestones []string } @@ -970,7 +1007,8 @@ func RepoResolveMetadataIDs(client *Client, repo ghrepo.Interface, input RepoRes // there is no way to look up projects nor milestones by name, so preload them all mi := RepoMetadataInput{ - Projects: len(input.Projects) > 0, + ProjectsV1: input.ProjectsV1, + ProjectsV2: input.ProjectsV2, Milestones: len(input.Milestones) > 0, } result, err := RepoMetadata(client, repo, mi) @@ -1237,26 +1275,12 @@ func RepoMilestones(client *Client, repo ghrepo.Interface, state string) ([]Repo return milestones, nil } -func ProjectNamesToPaths(client *Client, repo ghrepo.Interface, projectNames []string) ([]string, error) { - projects, projectsV2, err := relevantProjects(client, repo) - if err != nil { - return nil, err - } - return ProjectsToPaths(projects, projectsV2, projectNames) -} - -// RelevantProjects retrieves set of Projects and ProjectsV2 relevant to given repository: +// v1Projects retrieves set of RepoProjects relevant to given repository: // - Projects for repository // - Projects for repository organization, if it belongs to one -// - ProjectsV2 owned by current user -// - ProjectsV2 linked to repository -// - ProjectsV2 owned by repository organization, if it belongs to one -func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []ProjectV2, error) { +func v1Projects(client *Client, repo ghrepo.Interface) ([]RepoProject, error) { var repoProjects []RepoProject var orgProjects []RepoProject - var userProjectsV2 []ProjectV2 - var repoProjectsV2 []ProjectV2 - var orgProjectsV2 []ProjectV2 g, _ := errgroup.WithContext(context.Background()) @@ -1268,6 +1292,7 @@ func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []P } return err }) + g.Go(func() error { var err error orgProjects, err = OrganizationProjects(client, repo) @@ -1277,6 +1302,29 @@ func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []P } return nil }) + + if err := g.Wait(); err != nil { + return nil, err + } + + projects := make([]RepoProject, 0, len(repoProjects)+len(orgProjects)) + projects = append(projects, repoProjects...) + projects = append(projects, orgProjects...) + + return projects, nil +} + +// v2Projects retrieves set of ProjectV2 relevant to given repository: +// - ProjectsV2 owned by current user +// - ProjectsV2 linked to repository +// - ProjectsV2 owned by repository organization, if it belongs to one +func v2Projects(client *Client, repo ghrepo.Interface) ([]ProjectV2, error) { + var userProjectsV2 []ProjectV2 + var repoProjectsV2 []ProjectV2 + var orgProjectsV2 []ProjectV2 + + g, _ := errgroup.WithContext(context.Background()) + g.Go(func() error { var err error userProjectsV2, err = CurrentUserProjectsV2(client, repo.RepoHost()) @@ -1286,6 +1334,7 @@ func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []P } return nil }) + g.Go(func() error { var err error repoProjectsV2, err = RepoProjectsV2(client, repo) @@ -1295,6 +1344,7 @@ func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []P } return nil }) + g.Go(func() error { var err error orgProjectsV2, err = OrganizationProjectsV2(client, repo) @@ -1308,13 +1358,9 @@ func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []P }) if err := g.Wait(); err != nil { - return nil, nil, err + return nil, err } - projects := make([]RepoProject, 0, len(repoProjects)+len(orgProjects)) - projects = append(projects, repoProjects...) - projects = append(projects, orgProjects...) - // ProjectV2 might appear across multiple queries so use a map to keep them deduplicated. m := make(map[string]ProjectV2, len(userProjectsV2)+len(repoProjectsV2)+len(orgProjectsV2)) for _, p := range userProjectsV2 { @@ -1331,7 +1377,7 @@ func relevantProjects(client *Client, repo ghrepo.Interface) ([]RepoProject, []P projectsV2 = append(projectsV2, p) } - return projects, projectsV2, nil + return projectsV2, nil } func CreateRepoTransformToV4(apiClient *Client, hostname string, method string, path string, body io.Reader) (*Repository, error) { diff --git a/api/queries_repo_test.go b/api/queries_repo_test.go index 13aee459a1e..72ed357760b 100644 --- a/api/queries_repo_test.go +++ b/api/queries_repo_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" "github.com/stretchr/testify/assert" @@ -44,7 +45,8 @@ func Test_RepoMetadata(t *testing.T) { Assignees: true, Reviewers: true, Labels: true, - Projects: true, + ProjectsV1: true, + ProjectsV2: true, Milestones: true, } @@ -211,37 +213,16 @@ func Test_RepoMetadata(t *testing.T) { } } -func Test_ProjectsToPaths(t *testing.T) { - expectedProjectPaths := []string{"OWNER/REPO/PROJECT_NUMBER", "ORG/PROJECT_NUMBER", "OWNER/REPO/PROJECT_NUMBER_2"} - projects := []RepoProject{ - {ID: "id1", Name: "My Project", ResourcePath: "/OWNER/REPO/projects/PROJECT_NUMBER"}, - {ID: "id2", Name: "Org Project", ResourcePath: "/orgs/ORG/projects/PROJECT_NUMBER"}, - {ID: "id3", Name: "Project", ResourcePath: "/orgs/ORG/projects/PROJECT_NUMBER_2"}, - } - projectsV2 := []ProjectV2{ - {ID: "id4", Title: "My Project V2", ResourcePath: "/OWNER/REPO/projects/PROJECT_NUMBER_2"}, - {ID: "id5", Title: "Org Project V2", ResourcePath: "/orgs/ORG/projects/PROJECT_NUMBER_3"}, - } - projectNames := []string{"My Project", "Org Project", "My Project V2"} - - projectPaths, err := ProjectsToPaths(projects, projectsV2, projectNames) - if err != nil { - t.Errorf("error resolving projects: %v", err) - } - if !sliceEqual(projectPaths, expectedProjectPaths) { - t.Errorf("expected projects %v, got %v", expectedProjectPaths, projectPaths) - } -} - func Test_ProjectNamesToPaths(t *testing.T) { - http := &httpmock.Registry{} - client := newTestClient(http) + t.Run("when projectsV1 is supported, requests them", func(t *testing.T) { + http := &httpmock.Registry{} + client := newTestClient(http) - repo, _ := ghrepo.FromFullName("OWNER/REPO") + repo, _ := ghrepo.FromFullName("OWNER/REPO") - http.Register( - httpmock.GraphQL(`query RepositoryProjectList\b`), - httpmock.StringResponse(` + http.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(` { "data": { "repository": { "projects": { "nodes": [ { "name": "Cleanup", "id": "CLEANUPID", "resourcePath": "/OWNER/REPO/projects/1" }, @@ -250,9 +231,9 @@ func Test_ProjectNamesToPaths(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) - http.Register( - httpmock.GraphQL(`query OrganizationProjectList\b`), - httpmock.StringResponse(` + http.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(` { "data": { "organization": { "projects": { "nodes": [ { "name": "Triage", "id": "TRIAGEID", "resourcePath": "/orgs/ORG/projects/1" } @@ -260,9 +241,9 @@ func Test_ProjectNamesToPaths(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) - http.Register( - httpmock.GraphQL(`query RepositoryProjectV2List\b`), - httpmock.StringResponse(` + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` { "data": { "repository": { "projectsV2": { "nodes": [ { "title": "CleanupV2", "id": "CLEANUPV2ID", "resourcePath": "/OWNER/REPO/projects/3" }, @@ -271,9 +252,9 @@ func Test_ProjectNamesToPaths(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) - http.Register( - httpmock.GraphQL(`query OrganizationProjectV2List\b`), - httpmock.StringResponse(` + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` { "data": { "organization": { "projectsV2": { "nodes": [ { "title": "TriageV2", "id": "TRIAGEV2ID", "resourcePath": "/orgs/ORG/projects/2" } @@ -281,9 +262,9 @@ func Test_ProjectNamesToPaths(t *testing.T) { "pageInfo": { "hasNextPage": false } } } } } `)) - http.Register( - httpmock.GraphQL(`query UserProjectV2List\b`), - httpmock.StringResponse(` + http.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` { "data": { "viewer": { "projectsV2": { "nodes": [ { "title": "MonalisaV2", "id": "MONALISAV2ID", "resourcePath": "/users/MONALISA/projects/5" } @@ -292,15 +273,110 @@ func Test_ProjectNamesToPaths(t *testing.T) { } } } } `)) - projectPaths, err := ProjectNamesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2", "MonalisaV2"}) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + projectPaths, err := ProjectNamesToPaths(client, repo, []string{"Triage", "Roadmap", "TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Supported) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } - expectedProjectPaths := []string{"ORG/1", "OWNER/REPO/2", "ORG/2", "OWNER/REPO/4", "MONALISA/5"} - if !sliceEqual(projectPaths, expectedProjectPaths) { - t.Errorf("expected projects paths %v, got %v", expectedProjectPaths, projectPaths) - } + expectedProjectPaths := []string{"ORG/1", "OWNER/REPO/2", "ORG/2", "OWNER/REPO/4", "MONALISA/5"} + if !sliceEqual(projectPaths, expectedProjectPaths) { + t.Errorf("expected projects paths %v, got %v", expectedProjectPaths, projectPaths) + } + }) + + t.Run("when projectsV1 is not supported, does not request them", func(t *testing.T) { + http := &httpmock.Registry{} + client := newTestClient(http) + + repo, _ := ghrepo.FromFullName("OWNER/REPO") + + http.Exclude( + t, + httpmock.GraphQL(`query RepositoryProjectList\b`), + ) + http.Exclude( + t, + httpmock.GraphQL(`query OrganizationProjectList\b`), + ) + + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [ + { "title": "CleanupV2", "id": "CLEANUPV2ID", "resourcePath": "/OWNER/REPO/projects/3" }, + { "title": "RoadmapV2", "id": "ROADMAPV2ID", "resourcePath": "/OWNER/REPO/projects/4" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID", "resourcePath": "/orgs/ORG/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [ + { "title": "MonalisaV2", "id": "MONALISAV2ID", "resourcePath": "/users/MONALISA/projects/5" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + + projectPaths, err := ProjectNamesToPaths(client, repo, []string{"TriageV2", "RoadmapV2", "MonalisaV2"}, gh.ProjectsV1Unsupported) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedProjectPaths := []string{"ORG/2", "OWNER/REPO/4", "MONALISA/5"} + if !sliceEqual(projectPaths, expectedProjectPaths) { + t.Errorf("expected projects paths %v, got %v", expectedProjectPaths, projectPaths) + } + }) + + t.Run("when a project is not found, returns an error", func(t *testing.T) { + http := &httpmock.Registry{} + client := newTestClient(http) + + repo, _ := ghrepo.FromFullName("OWNER/REPO") + + // No projects found + http.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + http.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + + _, err := ProjectNamesToPaths(client, repo, []string{"TriageV2"}, gh.ProjectsV1Unsupported) + require.Equal(t, err, fmt.Errorf("'TriageV2' not found")) + }) } func Test_RepoResolveMetadataIDs(t *testing.T) { diff --git a/git/client.go b/git/client.go index 11a2e2e20a6..fe2819cf0d4 100644 --- a/git/client.go +++ b/git/client.go @@ -381,7 +381,6 @@ func (c *Client) lookupCommit(ctx context.Context, sha, format string) ([]byte, // Downstream consumers of ReadBranchConfig should consider the behavior they desire if this errors, // as an empty config is not necessarily breaking. func (c *Client) ReadBranchConfig(ctx context.Context, branch string) (BranchConfig, error) { - prefix := regexp.QuoteMeta(fmt.Sprintf("branch.%s.", branch)) args := []string{"config", "--get-regexp", fmt.Sprintf("^%s(remote|merge|pushremote|%s)$", prefix, MergeBaseConfig)} cmd, err := c.Command(ctx, args...) @@ -441,18 +440,50 @@ func (c *Client) SetBranchConfig(ctx context.Context, branch, name, value string return err } +// PushDefault defines the action git push should take if no refspec is given. +// See: https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault +type PushDefault string + +const ( + PushDefaultNothing PushDefault = "nothing" + PushDefaultCurrent PushDefault = "current" + PushDefaultUpstream PushDefault = "upstream" + PushDefaultTracking PushDefault = "tracking" + PushDefaultSimple PushDefault = "simple" + PushDefaultMatching PushDefault = "matching" +) + +func ParsePushDefault(s string) (PushDefault, error) { + validPushDefaults := map[string]struct{}{ + string(PushDefaultNothing): {}, + string(PushDefaultCurrent): {}, + string(PushDefaultUpstream): {}, + string(PushDefaultTracking): {}, + string(PushDefaultSimple): {}, + string(PushDefaultMatching): {}, + } + + if _, ok := validPushDefaults[s]; ok { + return PushDefault(s), nil + } + + return "", fmt.Errorf("unknown push.default value: %s", s) +} + // PushDefault returns the value of push.default in the config. If the value // is not set, it returns "simple" (the default git value). See // https://git-scm.com/docs/git-config#Documentation/git-config.txt-pushdefault -func (c *Client) PushDefault(ctx context.Context) (string, error) { +func (c *Client) PushDefault(ctx context.Context) (PushDefault, error) { pushDefault, err := c.Config(ctx, "push.default") if err == nil { - return pushDefault, nil + return ParsePushDefault(pushDefault) } + // If there is an error that the config key is not set, return the default value + // that git uses since 2.0. var gitError *GitError if ok := errors.As(err, &gitError); ok && gitError.ExitCode == 1 { - return "simple", nil + return PushDefaultSimple, nil } return "", err } @@ -473,13 +504,48 @@ func (c *Client) RemotePushDefault(ctx context.Context) (string, error) { return "", err } -// ParsePushRevision gets the value of the @{push} revision syntax +// RemoteTrackingRef is the structured form of the string "refs/remotes//". +// For example, the @{push} revision syntax could report "refs/remotes/origin/main" which would +// be parsed into RemoteTrackingRef{Remote: "origin", Branch: "main"}. +type RemoteTrackingRef struct { + Remote string + Branch string +} + +func (r RemoteTrackingRef) String() string { + return fmt.Sprintf("refs/remotes/%s/%s", r.Remote, r.Branch) +} + +// ParseRemoteTrackingRef parses a string of the form "refs/remotes//" into +// a RemoteTrackingBranch struct. If the string does not match this format, an error is returned. +func ParseRemoteTrackingRef(s string) (RemoteTrackingRef, error) { + parts := strings.Split(s, "/") + if len(parts) != 4 || parts[0] != "refs" || parts[1] != "remotes" { + return RemoteTrackingRef{}, fmt.Errorf("remote tracking branch must have format refs/remotes// but was: %s", s) + } + + return RemoteTrackingRef{ + Remote: parts[2], + Branch: parts[3], + }, nil +} + +// PushRevision gets the value of the @{push} revision syntax // An error here doesn't necessarily mean something is broken, but may mean that the @{push} // revision syntax couldn't be resolved, such as in non-centralized workflows with // push.default = simple. Downstream consumers should consider how to handle this error. -func (c *Client) ParsePushRevision(ctx context.Context, branch string) (string, error) { - revParseOut, err := c.revParse(ctx, "--abbrev-ref", branch+"@{push}") - return firstLine(revParseOut), err +func (c *Client) PushRevision(ctx context.Context, branch string) (RemoteTrackingRef, error) { + revParseOut, err := c.revParse(ctx, "--symbolic-full-name", branch+"@{push}") + if err != nil { + return RemoteTrackingRef{}, err + } + + ref, err := ParseRemoteTrackingRef(firstLine(revParseOut)) + if err != nil { + return RemoteTrackingRef{}, fmt.Errorf("could not parse push revision: %v", err) + } + + return ref, nil } func (c *Client) DeleteLocalTag(ctx context.Context, tag string) error { diff --git a/git/client_test.go b/git/client_test.go index 9fa076199e5..3d7560228be 100644 --- a/git/client_test.go +++ b/git/client_test.go @@ -952,7 +952,7 @@ func TestClientPushDefault(t *testing.T) { tests := []struct { name string commandResult commandResult - wantPushDefault string + wantPushDefault PushDefault wantError *GitError }{ { @@ -961,7 +961,7 @@ func TestClientPushDefault(t *testing.T) { ExitStatus: 1, Stderr: "error: key does not contain a section: remote.pushDefault", }, - wantPushDefault: "simple", + wantPushDefault: PushDefaultSimple, wantError: nil, }, { @@ -970,7 +970,7 @@ func TestClientPushDefault(t *testing.T) { ExitStatus: 0, Stdout: "current", }, - wantPushDefault: "current", + wantPushDefault: PushDefaultCurrent, wantError: nil, }, { @@ -1077,17 +1077,17 @@ func TestClientParsePushRevision(t *testing.T) { name string branch string commandResult commandResult - wantParsedPushRevision string - wantError *GitError + wantParsedPushRevision RemoteTrackingRef + wantError error }{ { - name: "@{push} resolves to origin/branchName", + name: "@{push} resolves to refs/remotes/origin/branchName", branch: "branchName", commandResult: commandResult{ ExitStatus: 0, - Stdout: "origin/branchName", + Stdout: "refs/remotes/origin/branchName", }, - wantParsedPushRevision: "origin/branchName", + wantParsedPushRevision: RemoteTrackingRef{Remote: "origin", Branch: "branchName"}, }, { name: "@{push} doesn't resolve", @@ -1095,16 +1095,25 @@ func TestClientParsePushRevision(t *testing.T) { ExitStatus: 128, Stderr: "fatal: git error", }, - wantParsedPushRevision: "", + wantParsedPushRevision: RemoteTrackingRef{}, wantError: &GitError{ ExitCode: 128, Stderr: "fatal: git error", }, }, + { + name: "@{push} resolves to something surprising", + commandResult: commandResult{ + ExitStatus: 0, + Stdout: "not/a/valid/remote/ref", + }, + wantParsedPushRevision: RemoteTrackingRef{}, + wantError: fmt.Errorf("could not parse push revision: remote tracking branch must have format refs/remotes// but was: not/a/valid/remote/ref"), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cmd := fmt.Sprintf("path/to/git rev-parse --abbrev-ref %s@{push}", tt.branch) + cmd := fmt.Sprintf("path/to/git rev-parse --symbolic-full-name %s@{push}", tt.branch) cmdCtx := createMockedCommandContext(t, mockedCommands{ args(cmd): tt.commandResult, }) @@ -1112,20 +1121,91 @@ func TestClientParsePushRevision(t *testing.T) { GitPath: "path/to/git", commandContext: cmdCtx, } - pushDefault, err := client.ParsePushRevision(context.Background(), tt.branch) + trackingRef, err := client.PushRevision(context.Background(), tt.branch) if tt.wantError != nil { - var gitError *GitError - require.ErrorAs(t, err, &gitError) - assert.Equal(t, tt.wantError.ExitCode, gitError.ExitCode) - assert.Equal(t, tt.wantError.Stderr, gitError.Stderr) + var wantErrorAsGit *GitError + if errors.As(err, &wantErrorAsGit) { + var gitError *GitError + require.ErrorAs(t, err, &gitError) + assert.Equal(t, wantErrorAsGit.ExitCode, gitError.ExitCode) + assert.Equal(t, wantErrorAsGit.Stderr, gitError.Stderr) + } else { + assert.Equal(t, err, tt.wantError) + } } else { require.NoError(t, err) } - assert.Equal(t, tt.wantParsedPushRevision, pushDefault) + assert.Equal(t, tt.wantParsedPushRevision, trackingRef) }) } } +func TestRemoteTrackingRef(t *testing.T) { + t.Run("parsing", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + remoteTrackingRef string + wantRemoteTrackingRef RemoteTrackingRef + wantError error + }{ + { + name: "valid remote tracking ref", + remoteTrackingRef: "refs/remotes/origin/branchName", + wantRemoteTrackingRef: RemoteTrackingRef{ + Remote: "origin", + Branch: "branchName", + }, + }, + { + name: "incorrect parts", + remoteTrackingRef: "refs/remotes/origin", + wantRemoteTrackingRef: RemoteTrackingRef{}, + wantError: fmt.Errorf("remote tracking branch must have format refs/remotes// but was: refs/remotes/origin"), + }, + { + name: "incorrect prefix type", + remoteTrackingRef: "invalid/remotes/origin/branchName", + wantRemoteTrackingRef: RemoteTrackingRef{}, + wantError: fmt.Errorf("remote tracking branch must have format refs/remotes// but was: invalid/remotes/origin/branchName"), + }, + { + name: "incorrect ref type", + remoteTrackingRef: "refs/invalid/origin/branchName", + wantRemoteTrackingRef: RemoteTrackingRef{}, + wantError: fmt.Errorf("remote tracking branch must have format refs/remotes// but was: refs/invalid/origin/branchName"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + trackingRef, err := ParseRemoteTrackingRef(tt.remoteTrackingRef) + if tt.wantError != nil { + require.Equal(t, tt.wantError, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantRemoteTrackingRef, trackingRef) + }) + } + }) + + t.Run("stringifying", func(t *testing.T) { + t.Parallel() + + remoteTrackingRef := RemoteTrackingRef{ + Remote: "origin", + Branch: "branchName", + } + + require.Equal(t, "refs/remotes/origin/branchName", remoteTrackingRef.String()) + }) +} + func TestClientDeleteLocalTag(t *testing.T) { tests := []struct { name string @@ -1992,6 +2072,41 @@ func TestCredentialPatternFromHost(t *testing.T) { } } +func TestPushDefault(t *testing.T) { + t.Run("it parses valid values correctly", func(t *testing.T) { + t.Parallel() + + tests := []struct { + value string + expectedPushDefault PushDefault + }{ + {"nothing", PushDefaultNothing}, + {"current", PushDefaultCurrent}, + {"upstream", PushDefaultUpstream}, + {"tracking", PushDefaultTracking}, + {"simple", PushDefaultSimple}, + {"matching", PushDefaultMatching}, + } + + for _, test := range tests { + t.Run(test.value, func(t *testing.T) { + t.Parallel() + + pushDefault, err := ParsePushDefault(test.value) + require.NoError(t, err) + assert.Equal(t, test.expectedPushDefault, pushDefault) + }) + } + }) + + t.Run("it returns an error for invalid values", func(t *testing.T) { + t.Parallel() + + _, err := ParsePushDefault("invalid") + require.Error(t, err) + }) +} + func createCommandContext(t *testing.T, exitStatus int, stdout, stderr string) (*exec.Cmd, commandCtx) { cmd := exec.CommandContext(context.Background(), os.Args[0], "-test.run=TestHelperProcess", "--") cmd.Env = []string{ diff --git a/git/command.go b/git/command.go index 8065ffd86be..c4614d086b4 100644 --- a/git/command.go +++ b/git/command.go @@ -43,10 +43,21 @@ func (gc *Command) Output() ([]byte, error) { out, err := run.PrepareCmd(gc.Cmd).Output() if err != nil { ge := GitError{err: err} + + // In real implementation, this should be an exec.ExitError, as below, + // but the tests use a different type because exec.ExitError are difficult + // to create. We want to get the exit code and stderr, but stderr + // is not a method and so tests can't access it. + // THIS MEANS THAT TESTS WILL NOT CORRECTLY HAVE STDERR SET, + // but at least tests can get the exit code. + var exitErrorWithExitCode errWithExitCode + if errors.As(err, &exitErrorWithExitCode) { + ge.ExitCode = exitErrorWithExitCode.ExitCode() + } + var exitError *exec.ExitError if errors.As(err, &exitError) { ge.Stderr = string(exitError.Stderr) - ge.ExitCode = exitError.ExitCode() } err = &ge } diff --git a/go.mod b/go.mod index e0f76dce842..31b07f2cf44 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/rivo/tview v0.0.0-20221029100920-c4a7e501810d github.com/shurcooL/githubv4 v0.0.0-20240120211514-18a1ae0e79dc github.com/sigstore/protobuf-specs v0.4.1 - github.com/sigstore/sigstore-go v0.7.1 + github.com/sigstore/sigstore-go v0.7.2 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 github.com/stretchr/testify v1.10.0 @@ -53,7 +53,7 @@ require ( golang.org/x/sync v0.13.0 golang.org/x/term v0.31.0 golang.org/x/text v0.24.0 - google.golang.org/grpc v1.71.0 + google.golang.org/grpc v1.71.1 google.golang.org/protobuf v1.36.6 gopkg.in/h2non/gock.v1 v1.1.2 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 5441c96c553..b312bcf6c5d 100644 --- a/go.sum +++ b/go.sum @@ -452,8 +452,8 @@ github.com/sigstore/rekor v1.3.9 h1:sUjRpKVh/hhgqGMs0t+TubgYsksArZ6poLEC3MsGAzU= github.com/sigstore/rekor v1.3.9/go.mod h1:xThNUhm6eNEmkJ/SiU/FVU7pLY2f380fSDZFsdDWlcM= github.com/sigstore/sigstore v1.9.1 h1:bNMsfFATsMPaagcf+uppLk4C9rQZ2dh5ysmCxQBYWaw= github.com/sigstore/sigstore v1.9.1/go.mod h1:zUoATYzR1J3rLNp3jmp4fzIJtWdhC3ZM6MnpcBtnsE4= -github.com/sigstore/sigstore-go v0.7.1 h1:lyzi3AjO6+BHc5zCf9fniycqPYOt3RaC08M/FRmQhVY= -github.com/sigstore/sigstore-go v0.7.1/go.mod h1:AIRj4I3LC82qd07VFm3T2zXYiddxeBV1k/eoS8nTz0E= +github.com/sigstore/sigstore-go v0.7.2 h1:CN4xPasChSEb0QBMxMW5dLcXdA9KD4QiRyVnMkhXj6U= +github.com/sigstore/sigstore-go v0.7.2/go.mod h1:AIRj4I3LC82qd07VFm3T2zXYiddxeBV1k/eoS8nTz0E= github.com/sigstore/sigstore/pkg/signature/kms/aws v1.9.1 h1:/YcNq687WnXpIRXl04nLfJX741G4iW+w+7Nem2Zy0f4= github.com/sigstore/sigstore/pkg/signature/kms/aws v1.9.1/go.mod h1:ApL9RpKsi7gkSYN0bMNdm/3jZ9EefxMmfYHfUmq2ZYM= github.com/sigstore/sigstore/pkg/signature/kms/azure v1.9.1 h1:FnusXyTIInnwfIOzzl5PFilRm1I97dxMSOcCkZBu9Kc= @@ -601,8 +601,8 @@ google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb h1: google.golang.org/genproto/googleapis/api v0.0.0-20250303144028-a0af3efb3deb/go.mod h1:jbe3Bkdp+Dh2IrslsFCklNhweNTBgSYanP1UXhJDhKg= google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 h1:iK2jbkWL86DXjEx0qiHcRE9dE4/Ahua5k6V8OWFb//c= google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= -google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= -google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= +google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/config/config.go b/internal/config/config.go index e7534dfdb5f..003a0ca171e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,18 +14,23 @@ import ( ghConfig "github.com/cli/go-gh/v2/pkg/config" ) +// Important: some of the following configuration settings are used outside of `cli/cli`, +// they are defined here to avoid `cli/cli` being changed unexpectedly. const ( + accessibleColorsKey = "accessible_colors" // used by cli/go-gh to enable the use of customizable, accessible 4-bit colors. + accessiblePrompterKey = "accessible_prompter" aliasesKey = "aliases" - browserKey = "browser" + browserKey = "browser" // used by cli/go-gh to open URLs in web browsers colorLabelsKey = "color_labels" - editorKey = "editor" + editorKey = "editor" // used by cli/go-gh to open interactive text editor gitProtocolKey = "git_protocol" - hostsKey = "hosts" + hostsKey = "hosts" // used by cli/go-gh to locate authenticated host tokens httpUnixSocketKey = "http_unix_socket" - oauthTokenKey = "oauth_token" + oauthTokenKey = "oauth_token" // used by cli/go-gh to locate authenticated host tokens pagerKey = "pager" promptKey = "prompt" preferEditorPromptKey = "prefer_editor_prompt" + spinnerKey = "spinner" userKey = "user" usersKey = "users" versionKey = "version" @@ -109,6 +114,16 @@ func (c *cfg) Authentication() gh.AuthConfig { return &AuthConfig{cfg: c.cfg} } +func (c *cfg) AccessibleColors(hostname string) gh.ConfigEntry { + // Intentionally panic if there is no user provided value or default value (which would be a programmer error) + return c.GetOrDefault(hostname, accessibleColorsKey).Unwrap() +} + +func (c *cfg) AccessiblePrompter(hostname string) gh.ConfigEntry { + // Intentionally panic if there is no user provided value or default value (which would be a programmer error) + return c.GetOrDefault(hostname, accessiblePrompterKey).Unwrap() +} + func (c *cfg) Browser(hostname string) gh.ConfigEntry { // Intentionally panic if there is no user provided value or default value (which would be a programmer error) return c.GetOrDefault(hostname, browserKey).Unwrap() @@ -149,6 +164,11 @@ func (c *cfg) PreferEditorPrompt(hostname string) gh.ConfigEntry { return c.GetOrDefault(hostname, preferEditorPromptKey).Unwrap() } +func (c *cfg) Spinner(hostname string) gh.ConfigEntry { + // Intentionally panic if there is no user provided value or default value (which would be a programmer error) + return c.GetOrDefault(hostname, spinnerKey).Unwrap() +} + func (c *cfg) Version() o.Option[string] { return c.get("", versionKey) } @@ -540,6 +560,12 @@ http_unix_socket: browser: # Whether to display labels using their RGB hex color codes in terminals that support truecolor. Supported values: enabled, disabled color_labels: disabled +# Whether customizable, 4-bit accessible colors should be used. Supported values: enabled, disabled +accessible_colors: disabled +# Whether an accessible prompter should be used. Supported values: enabled, disabled +accessible_prompter: disabled +# Whether to use a animated spinner as a progress indicator. If disabled, a textual progress indicator is used instead. Supported values: enabled, disabled +spinner: enabled ` type ConfigOption struct { @@ -619,6 +645,33 @@ var Options = []ConfigOption{ return c.ColorLabels(hostname).Value }, }, + { + Key: accessibleColorsKey, + Description: "whether customizable, 4-bit accessible colors should be used", + DefaultValue: "disabled", + AllowedValues: []string{"enabled", "disabled"}, + CurrentValue: func(c gh.Config, hostname string) string { + return c.AccessibleColors(hostname).Value + }, + }, + { + Key: accessiblePrompterKey, + Description: "whether an accessible prompter should be used", + DefaultValue: "disabled", + AllowedValues: []string{"enabled", "disabled"}, + CurrentValue: func(c gh.Config, hostname string) string { + return c.AccessiblePrompter(hostname).Value + }, + }, + { + Key: spinnerKey, + Description: "whether to use a animated spinner as a progress indicator", + DefaultValue: "enabled", + AllowedValues: []string{"enabled", "disabled"}, + CurrentValue: func(c gh.Config, hostname string) string { + return c.Spinner(hostname).Value + }, + }, } func HomeDirPath(subdir string) (string, error) { diff --git a/internal/config/stub.go b/internal/config/stub.go index 78073da4a17..ea60254db85 100644 --- a/internal/config/stub.go +++ b/internal/config/stub.go @@ -52,6 +52,12 @@ func NewFromString(cfgStr string) *ghmock.ConfigMock { }, } } + mock.AccessibleColorsFunc = func(hostname string) gh.ConfigEntry { + return cfg.AccessibleColors(hostname) + } + mock.AccessiblePrompterFunc = func(hostname string) gh.ConfigEntry { + return cfg.AccessiblePrompter(hostname) + } mock.BrowserFunc = func(hostname string) gh.ConfigEntry { return cfg.Browser(hostname) } @@ -76,6 +82,9 @@ func NewFromString(cfgStr string) *ghmock.ConfigMock { mock.PreferEditorPromptFunc = func(hostname string) gh.ConfigEntry { return cfg.PreferEditorPrompt(hostname) } + mock.SpinnerFunc = func(hostname string) gh.ConfigEntry { + return cfg.Spinner(hostname) + } mock.VersionFunc = func() o.Option[string] { return cfg.Version() } diff --git a/internal/featuredetection/detector_mock.go b/internal/featuredetection/detector_mock.go index 6f36dd3fc03..6f760f20949 100644 --- a/internal/featuredetection/detector_mock.go +++ b/internal/featuredetection/detector_mock.go @@ -1,5 +1,7 @@ package featuredetection +import "github.com/cli/cli/v2/internal/gh" + type DisabledDetectorMock struct{} func (md *DisabledDetectorMock) IssueFeatures() (IssueFeatures, error) { @@ -14,6 +16,10 @@ func (md *DisabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) return RepositoryFeatures{}, nil } +func (md *DisabledDetectorMock) ProjectsV1() gh.ProjectsV1Support { + return gh.ProjectsV1Unsupported +} + type EnabledDetectorMock struct{} func (md *EnabledDetectorMock) IssueFeatures() (IssueFeatures, error) { @@ -27,3 +33,7 @@ func (md *EnabledDetectorMock) PullRequestFeatures() (PullRequestFeatures, error func (md *EnabledDetectorMock) RepositoryFeatures() (RepositoryFeatures, error) { return allRepositoryFeatures, nil } + +func (md *EnabledDetectorMock) ProjectsV1() gh.ProjectsV1Support { + return gh.ProjectsV1Supported +} diff --git a/internal/featuredetection/feature_detection.go b/internal/featuredetection/feature_detection.go index a9bbe25f851..fba317f5874 100644 --- a/internal/featuredetection/feature_detection.go +++ b/internal/featuredetection/feature_detection.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/gh" "golang.org/x/sync/errgroup" ghauth "github.com/cli/go-gh/v2/pkg/auth" @@ -13,6 +14,7 @@ type Detector interface { IssueFeatures() (IssueFeatures, error) PullRequestFeatures() (PullRequestFeatures, error) RepositoryFeatures() (RepositoryFeatures, error) + ProjectsV1() gh.ProjectsV1Support } type IssueFeatures struct { @@ -199,3 +201,13 @@ func (d *detector) RepositoryFeatures() (RepositoryFeatures, error) { return features, nil } + +func (d *detector) ProjectsV1() gh.ProjectsV1Support { + // Currently, projects v1 support is entirely dependent on the host. As this is deprecated in GHES, + // we will do feature detection on whether the GHES version has support. + if ghauth.IsEnterprise(d.host) { + return gh.ProjectsV1Supported + } + + return gh.ProjectsV1Unsupported +} diff --git a/internal/featuredetection/feature_detection_test.go b/internal/featuredetection/feature_detection_test.go index 8af091c3f01..f1152da2cf7 100644 --- a/internal/featuredetection/feature_detection_test.go +++ b/internal/featuredetection/feature_detection_test.go @@ -5,8 +5,10 @@ import ( "testing" "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/pkg/httpmock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIssueFeatures(t *testing.T) { @@ -366,3 +368,19 @@ func TestRepositoryFeatures(t *testing.T) { }) } } + +func TestProjectV1Support(t *testing.T) { + t.Parallel() + + t.Run("when the host is enterprise, project v1 is supported", func(t *testing.T) { + detector := detector{host: "my.ghes.com"} + isProjectV1Supported := detector.ProjectsV1() + require.Equal(t, gh.ProjectsV1Supported, isProjectV1Supported) + }) + + t.Run("when the host is not enterprise, project v1 is not supported", func(t *testing.T) { + detector := detector{host: "github.com"} + isProjectV1Supported := detector.ProjectsV1() + require.Equal(t, gh.ProjectsV1Unsupported, isProjectV1Supported) + }) +} diff --git a/internal/gh/gh.go b/internal/gh/gh.go index b17c6bd67fb..aa90a5268b6 100644 --- a/internal/gh/gh.go +++ b/internal/gh/gh.go @@ -35,6 +35,10 @@ type Config interface { // Set provides primitive access for setting configuration values, optionally scoped by host. Set(hostname string, key string, value string) + // AccessibleColors returns the configured accessible_colors setting, optionally scoped by host. + AccessibleColors(hostname string) ConfigEntry + // AccessiblePrompter returns the configured accessible_prompter setting, optionally scoped by host. + AccessiblePrompter(hostname string) ConfigEntry // Browser returns the configured browser, optionally scoped by host. Browser(hostname string) ConfigEntry // ColorLabels returns the configured color_label setting, optionally scoped by host. @@ -51,6 +55,8 @@ type Config interface { Prompt(hostname string) ConfigEntry // PreferEditorPrompt returns the configured editor-based prompt, optionally scoped by host. PreferEditorPrompt(hostname string) ConfigEntry + // Spinner returns the configured spinner setting, optionally scoped by host. + Spinner(hostname string) ConfigEntry // Aliases provides persistent storage and modification of command aliases. Aliases() AliasConfig diff --git a/internal/gh/mock/config.go b/internal/gh/mock/config.go index b94cb084dc3..9f3f807993b 100644 --- a/internal/gh/mock/config.go +++ b/internal/gh/mock/config.go @@ -19,6 +19,12 @@ var _ gh.Config = &ConfigMock{} // // // make and configure a mocked gh.Config // mockedConfig := &ConfigMock{ +// AccessibleColorsFunc: func(hostname string) gh.ConfigEntry { +// panic("mock out the AccessibleColors method") +// }, +// AccessiblePrompterFunc: func(hostname string) gh.ConfigEntry { +// panic("mock out the AccessiblePrompter method") +// }, // AliasesFunc: func() gh.AliasConfig { // panic("mock out the Aliases method") // }, @@ -61,6 +67,9 @@ var _ gh.Config = &ConfigMock{} // SetFunc: func(hostname string, key string, value string) { // panic("mock out the Set method") // }, +// SpinnerFunc: func(hostname string) gh.ConfigEntry { +// panic("mock out the Spinner method") +// }, // VersionFunc: func() o.Option[string] { // panic("mock out the Version method") // }, @@ -74,6 +83,12 @@ var _ gh.Config = &ConfigMock{} // // } type ConfigMock struct { + // AccessibleColorsFunc mocks the AccessibleColors method. + AccessibleColorsFunc func(hostname string) gh.ConfigEntry + + // AccessiblePrompterFunc mocks the AccessiblePrompter method. + AccessiblePrompterFunc func(hostname string) gh.ConfigEntry + // AliasesFunc mocks the Aliases method. AliasesFunc func() gh.AliasConfig @@ -116,6 +131,9 @@ type ConfigMock struct { // SetFunc mocks the Set method. SetFunc func(hostname string, key string, value string) + // SpinnerFunc mocks the Spinner method. + SpinnerFunc func(hostname string) gh.ConfigEntry + // VersionFunc mocks the Version method. VersionFunc func() o.Option[string] @@ -124,6 +142,16 @@ type ConfigMock struct { // calls tracks calls to the methods. calls struct { + // AccessibleColors holds details about calls to the AccessibleColors method. + AccessibleColors []struct { + // Hostname is the hostname argument value. + Hostname string + } + // AccessiblePrompter holds details about calls to the AccessiblePrompter method. + AccessiblePrompter []struct { + // Hostname is the hostname argument value. + Hostname string + } // Aliases holds details about calls to the Aliases method. Aliases []struct { } @@ -194,6 +222,11 @@ type ConfigMock struct { // Value is the value argument value. Value string } + // Spinner holds details about calls to the Spinner method. + Spinner []struct { + // Hostname is the hostname argument value. + Hostname string + } // Version holds details about calls to the Version method. Version []struct { } @@ -201,6 +234,8 @@ type ConfigMock struct { Write []struct { } } + lockAccessibleColors sync.RWMutex + lockAccessiblePrompter sync.RWMutex lockAliases sync.RWMutex lockAuthentication sync.RWMutex lockBrowser sync.RWMutex @@ -215,10 +250,75 @@ type ConfigMock struct { lockPreferEditorPrompt sync.RWMutex lockPrompt sync.RWMutex lockSet sync.RWMutex + lockSpinner sync.RWMutex lockVersion sync.RWMutex lockWrite sync.RWMutex } +// AccessibleColors calls AccessibleColorsFunc. +func (mock *ConfigMock) AccessibleColors(hostname string) gh.ConfigEntry { + if mock.AccessibleColorsFunc == nil { + panic("ConfigMock.AccessibleColorsFunc: method is nil but Config.AccessibleColors was just called") + } + callInfo := struct { + Hostname string + }{ + Hostname: hostname, + } + mock.lockAccessibleColors.Lock() + mock.calls.AccessibleColors = append(mock.calls.AccessibleColors, callInfo) + mock.lockAccessibleColors.Unlock() + return mock.AccessibleColorsFunc(hostname) +} + +// AccessibleColorsCalls gets all the calls that were made to AccessibleColors. +// Check the length with: +// +// len(mockedConfig.AccessibleColorsCalls()) +func (mock *ConfigMock) AccessibleColorsCalls() []struct { + Hostname string +} { + var calls []struct { + Hostname string + } + mock.lockAccessibleColors.RLock() + calls = mock.calls.AccessibleColors + mock.lockAccessibleColors.RUnlock() + return calls +} + +// AccessiblePrompter calls AccessiblePrompterFunc. +func (mock *ConfigMock) AccessiblePrompter(hostname string) gh.ConfigEntry { + if mock.AccessiblePrompterFunc == nil { + panic("ConfigMock.AccessiblePrompterFunc: method is nil but Config.AccessiblePrompter was just called") + } + callInfo := struct { + Hostname string + }{ + Hostname: hostname, + } + mock.lockAccessiblePrompter.Lock() + mock.calls.AccessiblePrompter = append(mock.calls.AccessiblePrompter, callInfo) + mock.lockAccessiblePrompter.Unlock() + return mock.AccessiblePrompterFunc(hostname) +} + +// AccessiblePrompterCalls gets all the calls that were made to AccessiblePrompter. +// Check the length with: +// +// len(mockedConfig.AccessiblePrompterCalls()) +func (mock *ConfigMock) AccessiblePrompterCalls() []struct { + Hostname string +} { + var calls []struct { + Hostname string + } + mock.lockAccessiblePrompter.RLock() + calls = mock.calls.AccessiblePrompter + mock.lockAccessiblePrompter.RUnlock() + return calls +} + // Aliases calls AliasesFunc. func (mock *ConfigMock) Aliases() gh.AliasConfig { if mock.AliasesFunc == nil { @@ -664,6 +764,38 @@ func (mock *ConfigMock) SetCalls() []struct { return calls } +// Spinner calls SpinnerFunc. +func (mock *ConfigMock) Spinner(hostname string) gh.ConfigEntry { + if mock.SpinnerFunc == nil { + panic("ConfigMock.SpinnerFunc: method is nil but Config.Spinner was just called") + } + callInfo := struct { + Hostname string + }{ + Hostname: hostname, + } + mock.lockSpinner.Lock() + mock.calls.Spinner = append(mock.calls.Spinner, callInfo) + mock.lockSpinner.Unlock() + return mock.SpinnerFunc(hostname) +} + +// SpinnerCalls gets all the calls that were made to Spinner. +// Check the length with: +// +// len(mockedConfig.SpinnerCalls()) +func (mock *ConfigMock) SpinnerCalls() []struct { + Hostname string +} { + var calls []struct { + Hostname string + } + mock.lockSpinner.RLock() + calls = mock.calls.Spinner + mock.lockSpinner.RUnlock() + return calls +} + // Version calls VersionFunc. func (mock *ConfigMock) Version() o.Option[string] { if mock.VersionFunc == nil { diff --git a/internal/gh/projects.go b/internal/gh/projects.go new file mode 100644 index 00000000000..34acf8d7c58 --- /dev/null +++ b/internal/gh/projects.go @@ -0,0 +1,23 @@ +package gh + +// ProjectsV1Support provides type safety and readability around whether or not Projects v1 is supported +// by the targeted host. +// +// It is a sealed type to ensure that consumers must use the exported ProjectsV1Supported and ProjectsV1Unsupported +// variables to get an instance of the type. +type ProjectsV1Support interface { + sealed() +} + +type projectsV1Supported struct{} + +func (projectsV1Supported) sealed() {} + +type projectsV1Unsupported struct{} + +func (projectsV1Unsupported) sealed() {} + +var ( + ProjectsV1Supported ProjectsV1Support = projectsV1Supported{} + ProjectsV1Unsupported ProjectsV1Support = projectsV1Unsupported{} +) diff --git a/internal/prompter/accessible_prompter_test.go b/internal/prompter/accessible_prompter_test.go index 56096972d12..619eb14f131 100644 --- a/internal/prompter/accessible_prompter_test.go +++ b/internal/prompter/accessible_prompter_test.go @@ -11,6 +11,7 @@ import ( "github.com/Netflix/go-expect" "github.com/cli/cli/v2/internal/prompter" + "github.com/cli/cli/v2/pkg/iostreams" "github.com/creack/pty" "github.com/hinshun/vt10x" "github.com/stretchr/testify/assert" @@ -33,7 +34,7 @@ import ( func TestAccessiblePrompter(t *testing.T) { t.Run("Select", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) go func() { // Wait for prompt to appear @@ -52,7 +53,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("MultiSelect", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) go func() { // Wait for prompt to appear @@ -77,7 +78,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("Input", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) dummyText := "12345abcdefg" go func() { @@ -97,7 +98,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("Input - blank input returns default value", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) dummyDefaultValue := "12345abcdefg" go func() { @@ -117,7 +118,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("Password", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) dummyPassword := "12345abcdefg" go func() { @@ -137,7 +138,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("Confirm", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) go func() { // Wait for prompt to appear @@ -156,7 +157,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("Confirm - blank input returns default", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) go func() { // Wait for prompt to appear @@ -175,7 +176,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("AuthToken", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) dummyAuthToken := "12345abcdefg" go func() { @@ -195,7 +196,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("AuthToken - blank input returns error", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) dummyAuthTokenForAfterFailure := "12345abcdefg" go func() { @@ -223,7 +224,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("ConfirmDeletion", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) requiredValue := "test" go func() { @@ -243,7 +244,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("ConfirmDeletion - bad input", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) requiredValue := "test" badInputValue := "garbage" @@ -272,7 +273,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("InputHostname", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) hostname := "example.com" go func() { @@ -292,7 +293,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("MarkdownEditor - blank allowed with blank input returns blank", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) go func() { // Wait for prompt to appear @@ -311,7 +312,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("MarkdownEditor - blank disallowed with default value returns default value", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) defaultValue := "12345abcdefg" go func() { @@ -339,7 +340,7 @@ func TestAccessiblePrompter(t *testing.T) { t.Run("MarkdownEditor - blank disallowed no default value returns error", func(t *testing.T) { console := newTestVirtualTerminal(t) - p := newTestAcessiblePrompter(t, console) + p := newTestAccessiblePrompter(t, console) go func() { // Wait for prompt to appear @@ -419,21 +420,40 @@ func newTestVirtualTerminal(t *testing.T) *expect.Console { return console } -func newTestAcessiblePrompter(t *testing.T, console *expect.Console) prompter.Prompter { +func newTestVirtualTerminalIOStreams(t *testing.T, console *expect.Console) *iostreams.IOStreams { t.Helper() + io := &iostreams.IOStreams{ + In: console.Tty(), + Out: console.Tty(), + ErrOut: console.Tty(), + } + io.SetStdinTTY(false) + io.SetStdoutTTY(false) + io.SetStderrTTY(false) + return io +} - t.Setenv("GH_ACCESSIBLE_PROMPTER", "true") - // `echo`` is chose as the editor command because it immediately returns - // a success exit code, returns an empty string, doesn't require any user input, - // and since this file is only built on Linux, it is near guaranteed to be available. - return prompter.New("echo", console.Tty(), console.Tty(), console.Tty()) +// `echo` is chosen as the editor command because it immediately returns +// a success exit code, returns an empty string, doesn't require any user input, +// and since this file is only built on Linux, it is near guaranteed to be available. +var editorCmd = "echo" + +func newTestAccessiblePrompter(t *testing.T, console *expect.Console) prompter.Prompter { + t.Helper() + + io := newTestVirtualTerminalIOStreams(t, console) + io.SetAccessiblePrompterEnabled(true) + + return prompter.New(editorCmd, io) } func newTestSurveyPrompter(t *testing.T, console *expect.Console) prompter.Prompter { t.Helper() - t.Setenv("GH_ACCESSIBLE_PROMPTER", "false") - return prompter.New("echo", console.Tty(), console.Tty(), console.Tty()) + io := newTestVirtualTerminalIOStreams(t, console) + io.SetAccessiblePrompterEnabled(false) + + return prompter.New(editorCmd, io) } // failOnExpectError adds an observer that will fail the test in a standardised way diff --git a/internal/prompter/prompter.go b/internal/prompter/prompter.go index 6ef61cf1583..2a432836668 100644 --- a/internal/prompter/prompter.go +++ b/internal/prompter/prompter.go @@ -2,13 +2,12 @@ package prompter import ( "fmt" - "os" - "slices" "strings" "github.com/AlecAivazis/survey/v2" "github.com/charmbracelet/huh" "github.com/cli/cli/v2/internal/ghinstance" + "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/surveyext" ghPrompter "github.com/cli/go-gh/v2/pkg/prompter" ) @@ -43,24 +42,21 @@ type Prompter interface { MarkdownEditor(prompt string, defaultValue string, blankAllowed bool) (string, error) } -func New(editorCmd string, stdin ghPrompter.FileReader, stdout ghPrompter.FileWriter, stderr ghPrompter.FileWriter) Prompter { - accessiblePrompterValue, accessiblePrompterIsSet := os.LookupEnv("GH_ACCESSIBLE_PROMPTER") - falseyValues := []string{"false", "0", "no", ""} - - if accessiblePrompterIsSet && !slices.Contains(falseyValues, accessiblePrompterValue) { +func New(editorCmd string, io *iostreams.IOStreams) Prompter { + if io.AccessiblePrompterEnabled() { return &accessiblePrompter{ - stdin: stdin, - stdout: stdout, - stderr: stderr, + stdin: io.In, + stdout: io.Out, + stderr: io.ErrOut, editorCmd: editorCmd, } } return &surveyPrompter{ - prompter: ghPrompter.New(stdin, stdout, stderr), - stdin: stdin, - stdout: stdout, - stderr: stderr, + prompter: ghPrompter.New(io.In, io.Out, io.ErrOut), + stdin: io.In, + stdout: io.Out, + stderr: io.ErrOut, editorCmd: editorCmd, } } diff --git a/internal/prompter/prompter_mock.go b/internal/prompter/prompter_mock.go index b817a491f99..b15f8bf96a7 100644 --- a/internal/prompter/prompter_mock.go +++ b/internal/prompter/prompter_mock.go @@ -20,28 +20,28 @@ var _ Prompter = &PrompterMock{} // AuthTokenFunc: func() (string, error) { // panic("mock out the AuthToken method") // }, -// ConfirmFunc: func(s string, b bool) (bool, error) { +// ConfirmFunc: func(prompt string, defaultValue bool) (bool, error) { // panic("mock out the Confirm method") // }, -// ConfirmDeletionFunc: func(s string) error { +// ConfirmDeletionFunc: func(requiredValue string) error { // panic("mock out the ConfirmDeletion method") // }, -// InputFunc: func(s1 string, s2 string) (string, error) { +// InputFunc: func(prompt string, defaultValue string) (string, error) { // panic("mock out the Input method") // }, // InputHostnameFunc: func() (string, error) { // panic("mock out the InputHostname method") // }, -// MarkdownEditorFunc: func(s1 string, s2 string, b bool) (string, error) { +// MarkdownEditorFunc: func(prompt string, defaultValue string, blankAllowed bool) (string, error) { // panic("mock out the MarkdownEditor method") // }, // MultiSelectFunc: func(prompt string, defaults []string, options []string) ([]int, error) { // panic("mock out the MultiSelect method") // }, -// PasswordFunc: func(s string) (string, error) { +// PasswordFunc: func(prompt string) (string, error) { // panic("mock out the Password method") // }, -// SelectFunc: func(s1 string, s2 string, strings []string) (int, error) { +// SelectFunc: func(prompt string, defaultValue string, options []string) (int, error) { // panic("mock out the Select method") // }, // } @@ -55,28 +55,28 @@ type PrompterMock struct { AuthTokenFunc func() (string, error) // ConfirmFunc mocks the Confirm method. - ConfirmFunc func(s string, b bool) (bool, error) + ConfirmFunc func(prompt string, defaultValue bool) (bool, error) // ConfirmDeletionFunc mocks the ConfirmDeletion method. - ConfirmDeletionFunc func(s string) error + ConfirmDeletionFunc func(requiredValue string) error // InputFunc mocks the Input method. - InputFunc func(s1 string, s2 string) (string, error) + InputFunc func(prompt string, defaultValue string) (string, error) // InputHostnameFunc mocks the InputHostname method. InputHostnameFunc func() (string, error) // MarkdownEditorFunc mocks the MarkdownEditor method. - MarkdownEditorFunc func(s1 string, s2 string, b bool) (string, error) + MarkdownEditorFunc func(prompt string, defaultValue string, blankAllowed bool) (string, error) // MultiSelectFunc mocks the MultiSelect method. MultiSelectFunc func(prompt string, defaults []string, options []string) ([]int, error) // PasswordFunc mocks the Password method. - PasswordFunc func(s string) (string, error) + PasswordFunc func(prompt string) (string, error) // SelectFunc mocks the Select method. - SelectFunc func(s1 string, s2 string, strings []string) (int, error) + SelectFunc func(prompt string, defaultValue string, options []string) (int, error) // calls tracks calls to the methods. calls struct { @@ -85,34 +85,34 @@ type PrompterMock struct { } // Confirm holds details about calls to the Confirm method. Confirm []struct { - // S is the s argument value. - S string - // B is the b argument value. - B bool + // Prompt is the prompt argument value. + Prompt string + // DefaultValue is the defaultValue argument value. + DefaultValue bool } // ConfirmDeletion holds details about calls to the ConfirmDeletion method. ConfirmDeletion []struct { - // S is the s argument value. - S string + // RequiredValue is the requiredValue argument value. + RequiredValue string } // Input holds details about calls to the Input method. Input []struct { - // S1 is the s1 argument value. - S1 string - // S2 is the s2 argument value. - S2 string + // Prompt is the prompt argument value. + Prompt string + // DefaultValue is the defaultValue argument value. + DefaultValue string } // InputHostname holds details about calls to the InputHostname method. InputHostname []struct { } // MarkdownEditor holds details about calls to the MarkdownEditor method. MarkdownEditor []struct { - // S1 is the s1 argument value. - S1 string - // S2 is the s2 argument value. - S2 string - // B is the b argument value. - B bool + // Prompt is the prompt argument value. + Prompt string + // DefaultValue is the defaultValue argument value. + DefaultValue string + // BlankAllowed is the blankAllowed argument value. + BlankAllowed bool } // MultiSelect holds details about calls to the MultiSelect method. MultiSelect []struct { @@ -125,17 +125,17 @@ type PrompterMock struct { } // Password holds details about calls to the Password method. Password []struct { - // S is the s argument value. - S string + // Prompt is the prompt argument value. + Prompt string } // Select holds details about calls to the Select method. Select []struct { - // S1 is the s1 argument value. - S1 string - // S2 is the s2 argument value. - S2 string - // Strings is the strings argument value. - Strings []string + // Prompt is the prompt argument value. + Prompt string + // DefaultValue is the defaultValue argument value. + DefaultValue string + // Options is the options argument value. + Options []string } } lockAuthToken sync.RWMutex @@ -177,21 +177,21 @@ func (mock *PrompterMock) AuthTokenCalls() []struct { } // Confirm calls ConfirmFunc. -func (mock *PrompterMock) Confirm(s string, b bool) (bool, error) { +func (mock *PrompterMock) Confirm(prompt string, defaultValue bool) (bool, error) { if mock.ConfirmFunc == nil { panic("PrompterMock.ConfirmFunc: method is nil but Prompter.Confirm was just called") } callInfo := struct { - S string - B bool + Prompt string + DefaultValue bool }{ - S: s, - B: b, + Prompt: prompt, + DefaultValue: defaultValue, } mock.lockConfirm.Lock() mock.calls.Confirm = append(mock.calls.Confirm, callInfo) mock.lockConfirm.Unlock() - return mock.ConfirmFunc(s, b) + return mock.ConfirmFunc(prompt, defaultValue) } // ConfirmCalls gets all the calls that were made to Confirm. @@ -199,12 +199,12 @@ func (mock *PrompterMock) Confirm(s string, b bool) (bool, error) { // // len(mockedPrompter.ConfirmCalls()) func (mock *PrompterMock) ConfirmCalls() []struct { - S string - B bool + Prompt string + DefaultValue bool } { var calls []struct { - S string - B bool + Prompt string + DefaultValue bool } mock.lockConfirm.RLock() calls = mock.calls.Confirm @@ -213,19 +213,19 @@ func (mock *PrompterMock) ConfirmCalls() []struct { } // ConfirmDeletion calls ConfirmDeletionFunc. -func (mock *PrompterMock) ConfirmDeletion(s string) error { +func (mock *PrompterMock) ConfirmDeletion(requiredValue string) error { if mock.ConfirmDeletionFunc == nil { panic("PrompterMock.ConfirmDeletionFunc: method is nil but Prompter.ConfirmDeletion was just called") } callInfo := struct { - S string + RequiredValue string }{ - S: s, + RequiredValue: requiredValue, } mock.lockConfirmDeletion.Lock() mock.calls.ConfirmDeletion = append(mock.calls.ConfirmDeletion, callInfo) mock.lockConfirmDeletion.Unlock() - return mock.ConfirmDeletionFunc(s) + return mock.ConfirmDeletionFunc(requiredValue) } // ConfirmDeletionCalls gets all the calls that were made to ConfirmDeletion. @@ -233,10 +233,10 @@ func (mock *PrompterMock) ConfirmDeletion(s string) error { // // len(mockedPrompter.ConfirmDeletionCalls()) func (mock *PrompterMock) ConfirmDeletionCalls() []struct { - S string + RequiredValue string } { var calls []struct { - S string + RequiredValue string } mock.lockConfirmDeletion.RLock() calls = mock.calls.ConfirmDeletion @@ -245,21 +245,21 @@ func (mock *PrompterMock) ConfirmDeletionCalls() []struct { } // Input calls InputFunc. -func (mock *PrompterMock) Input(s1 string, s2 string) (string, error) { +func (mock *PrompterMock) Input(prompt string, defaultValue string) (string, error) { if mock.InputFunc == nil { panic("PrompterMock.InputFunc: method is nil but Prompter.Input was just called") } callInfo := struct { - S1 string - S2 string + Prompt string + DefaultValue string }{ - S1: s1, - S2: s2, + Prompt: prompt, + DefaultValue: defaultValue, } mock.lockInput.Lock() mock.calls.Input = append(mock.calls.Input, callInfo) mock.lockInput.Unlock() - return mock.InputFunc(s1, s2) + return mock.InputFunc(prompt, defaultValue) } // InputCalls gets all the calls that were made to Input. @@ -267,12 +267,12 @@ func (mock *PrompterMock) Input(s1 string, s2 string) (string, error) { // // len(mockedPrompter.InputCalls()) func (mock *PrompterMock) InputCalls() []struct { - S1 string - S2 string + Prompt string + DefaultValue string } { var calls []struct { - S1 string - S2 string + Prompt string + DefaultValue string } mock.lockInput.RLock() calls = mock.calls.Input @@ -308,23 +308,23 @@ func (mock *PrompterMock) InputHostnameCalls() []struct { } // MarkdownEditor calls MarkdownEditorFunc. -func (mock *PrompterMock) MarkdownEditor(s1 string, s2 string, b bool) (string, error) { +func (mock *PrompterMock) MarkdownEditor(prompt string, defaultValue string, blankAllowed bool) (string, error) { if mock.MarkdownEditorFunc == nil { panic("PrompterMock.MarkdownEditorFunc: method is nil but Prompter.MarkdownEditor was just called") } callInfo := struct { - S1 string - S2 string - B bool + Prompt string + DefaultValue string + BlankAllowed bool }{ - S1: s1, - S2: s2, - B: b, + Prompt: prompt, + DefaultValue: defaultValue, + BlankAllowed: blankAllowed, } mock.lockMarkdownEditor.Lock() mock.calls.MarkdownEditor = append(mock.calls.MarkdownEditor, callInfo) mock.lockMarkdownEditor.Unlock() - return mock.MarkdownEditorFunc(s1, s2, b) + return mock.MarkdownEditorFunc(prompt, defaultValue, blankAllowed) } // MarkdownEditorCalls gets all the calls that were made to MarkdownEditor. @@ -332,14 +332,14 @@ func (mock *PrompterMock) MarkdownEditor(s1 string, s2 string, b bool) (string, // // len(mockedPrompter.MarkdownEditorCalls()) func (mock *PrompterMock) MarkdownEditorCalls() []struct { - S1 string - S2 string - B bool + Prompt string + DefaultValue string + BlankAllowed bool } { var calls []struct { - S1 string - S2 string - B bool + Prompt string + DefaultValue string + BlankAllowed bool } mock.lockMarkdownEditor.RLock() calls = mock.calls.MarkdownEditor @@ -388,19 +388,19 @@ func (mock *PrompterMock) MultiSelectCalls() []struct { } // Password calls PasswordFunc. -func (mock *PrompterMock) Password(s string) (string, error) { +func (mock *PrompterMock) Password(prompt string) (string, error) { if mock.PasswordFunc == nil { panic("PrompterMock.PasswordFunc: method is nil but Prompter.Password was just called") } callInfo := struct { - S string + Prompt string }{ - S: s, + Prompt: prompt, } mock.lockPassword.Lock() mock.calls.Password = append(mock.calls.Password, callInfo) mock.lockPassword.Unlock() - return mock.PasswordFunc(s) + return mock.PasswordFunc(prompt) } // PasswordCalls gets all the calls that were made to Password. @@ -408,10 +408,10 @@ func (mock *PrompterMock) Password(s string) (string, error) { // // len(mockedPrompter.PasswordCalls()) func (mock *PrompterMock) PasswordCalls() []struct { - S string + Prompt string } { var calls []struct { - S string + Prompt string } mock.lockPassword.RLock() calls = mock.calls.Password @@ -420,23 +420,23 @@ func (mock *PrompterMock) PasswordCalls() []struct { } // Select calls SelectFunc. -func (mock *PrompterMock) Select(s1 string, s2 string, strings []string) (int, error) { +func (mock *PrompterMock) Select(prompt string, defaultValue string, options []string) (int, error) { if mock.SelectFunc == nil { panic("PrompterMock.SelectFunc: method is nil but Prompter.Select was just called") } callInfo := struct { - S1 string - S2 string - Strings []string + Prompt string + DefaultValue string + Options []string }{ - S1: s1, - S2: s2, - Strings: strings, + Prompt: prompt, + DefaultValue: defaultValue, + Options: options, } mock.lockSelect.Lock() mock.calls.Select = append(mock.calls.Select, callInfo) mock.lockSelect.Unlock() - return mock.SelectFunc(s1, s2, strings) + return mock.SelectFunc(prompt, defaultValue, options) } // SelectCalls gets all the calls that were made to Select. @@ -444,14 +444,14 @@ func (mock *PrompterMock) Select(s1 string, s2 string, strings []string) (int, e // // len(mockedPrompter.SelectCalls()) func (mock *PrompterMock) SelectCalls() []struct { - S1 string - S2 string - Strings []string + Prompt string + DefaultValue string + Options []string } { var calls []struct { - S1 string - S2 string - Strings []string + Prompt string + DefaultValue string + Options []string } mock.lockSelect.RLock() calls = mock.calls.Select diff --git a/internal/run/stub.go b/internal/run/stub.go index 5cd3c6de59c..507fd61d6f9 100644 --- a/internal/run/stub.go +++ b/internal/run/stub.go @@ -46,7 +46,7 @@ func Stub() (*CommandStubber, func(T)) { return } t.Helper() - t.Errorf("unmatched stubs (%d): %s", len(unmatched), strings.Join(unmatched, ", ")) + t.Errorf("unmatched exec stubs (%d): %s", len(unmatched), strings.Join(unmatched, ", ")) } } diff --git a/pkg/cmd/attestation/inspect/inspect.go b/pkg/cmd/attestation/inspect/inspect.go index 6fbddd6da35..b571eee0100 100644 --- a/pkg/cmd/attestation/inspect/inspect.go +++ b/pkg/cmd/attestation/inspect/inspect.go @@ -105,7 +105,11 @@ func NewInspectCmd(f *cmdutil.Factory, runF func(*Options) error) *cobra.Command config.TrustDomain = td } - opts.SigstoreVerifier = verification.NewLiveSigstoreVerifier(config) + sgVerifier, err := verification.NewLiveSigstoreVerifier(config) + if err != nil { + return fmt.Errorf("failed to create Sigstore verifier: %w", err) + } + opts.SigstoreVerifier = sgVerifier if runF != nil { return runF(opts) diff --git a/pkg/cmd/attestation/verification/sigstore.go b/pkg/cmd/attestation/verification/sigstore.go index 6dd31dac0da..190ea5c0f1e 100644 --- a/pkg/cmd/attestation/verification/sigstore.go +++ b/pkg/cmd/attestation/verification/sigstore.go @@ -44,12 +44,11 @@ type SigstoreVerifier interface { } type LiveSigstoreVerifier struct { - TrustedRoot string Logger *io.Handler NoPublicGood bool - // If tenancy mode is not used, trust domain is empty - TrustDomain string - TUFMetadataDir o.Option[string] + PublicGood *verify.SignedEntityVerifier + GitHub *verify.SignedEntityVerifier + Custom map[string]*verify.SignedEntityVerifier } var ErrNoAttestationsVerified = errors.New("no attestations were verified") @@ -57,56 +56,43 @@ var ErrNoAttestationsVerified = errors.New("no attestations were verified") // NewLiveSigstoreVerifier creates a new LiveSigstoreVerifier struct // that is used to verify artifacts and attestations against the // Public Good, GitHub, or a custom trusted root. -func NewLiveSigstoreVerifier(config SigstoreConfig) *LiveSigstoreVerifier { - return &LiveSigstoreVerifier{ - TrustedRoot: config.TrustedRoot, - Logger: config.Logger, - NoPublicGood: config.NoPublicGood, - TrustDomain: config.TrustDomain, - TUFMetadataDir: config.TUFMetadataDir, +func NewLiveSigstoreVerifier(config SigstoreConfig) (*LiveSigstoreVerifier, error) { + liveVerifier := &LiveSigstoreVerifier{ + Logger: config.Logger, + NoPublicGood: config.NoPublicGood, } -} - -func getBundleIssuer(b *bundle.Bundle) (string, error) { - if !b.MinVersion("0.2") { - return "", fmt.Errorf("unsupported bundle version: %s", b.MediaType) - } - verifyContent, err := b.VerificationContent() - if err != nil { - return "", fmt.Errorf("failed to get bundle verification content: %v", err) + // if a custom trusted root is set, configure custom verifiers + if config.TrustedRoot != "" { + customVerifiers, err := createCustomVerifiers(config.TrustedRoot, config.NoPublicGood) + if err != nil { + return nil, err + } + liveVerifier.Custom = customVerifiers + return liveVerifier, nil } - leafCert := verifyContent.Certificate() - if leafCert == nil { - return "", fmt.Errorf("leaf cert not found") + if !config.NoPublicGood { + publicGoodVerifier, err := newPublicGoodVerifier(config.TUFMetadataDir) + if err != nil { + return nil, err + } + liveVerifier.PublicGood = publicGoodVerifier } - if len(leafCert.Issuer.Organization) != 1 { - return "", fmt.Errorf("expected the leaf certificate issuer to only have one organization") + github, err := newGitHubVerifier(config.TrustDomain, config.TUFMetadataDir) + if err != nil { + return nil, err } - return leafCert.Issuer.Organization[0], nil -} + liveVerifier.GitHub = github -func (v *LiveSigstoreVerifier) chooseVerifier(issuer string) (*verify.SignedEntityVerifier, error) { - // if no custom trusted root is set, attempt to create a Public Good or - // GitHub Sigstore verifier - if v.TrustedRoot == "" { - switch issuer { - case PublicGoodIssuerOrg: - if v.NoPublicGood { - return nil, fmt.Errorf("detected public good instance but requested verification without public good instance") - } - return newPublicGoodVerifier(v.TUFMetadataDir) - case GitHubIssuerOrg: - return newGitHubVerifier(v.TrustDomain, v.TUFMetadataDir) - default: - return nil, fmt.Errorf("leaf certificate issuer is not recognized") - } - } + return liveVerifier, nil +} - customTrustRoots, err := os.ReadFile(v.TrustedRoot) +func createCustomVerifiers(trustedRoot string, noPublicGood bool) (map[string]*verify.SignedEntityVerifier, error) { + customTrustRoots, err := os.ReadFile(trustedRoot) if err != nil { - return nil, fmt.Errorf("unable to read file %s: %v", v.TrustedRoot, err) + return nil, fmt.Errorf("unable to read file %s: %v", trustedRoot, err) } + verifiers := make(map[string]*verify.SignedEntityVerifier) reader := bufio.NewReader(bytes.NewReader(customTrustRoots)) var line []byte var readError error @@ -130,10 +116,11 @@ func (v *LiveSigstoreVerifier) chooseVerifier(issuer string) (*verify.SignedEnti return nil, err } - // if the custom trusted root issuer is not set or doesn't match the given issuer, skip it - if len(lowestCert.Issuer.Organization) == 0 || lowestCert.Issuer.Organization[0] != issuer { + // if the custom trusted root issuer is not set, skip it + if len(lowestCert.Issuer.Organization) == 0 { continue } + issuer := lowestCert.Issuer.Organization[0] // Determine what policy to use with this trusted root. // @@ -141,21 +128,88 @@ func (v *LiveSigstoreVerifier) chooseVerifier(issuer string) (*verify.SignedEnti // issuer. We *must* use the trusted root provided. switch issuer { case PublicGoodIssuerOrg: - if v.NoPublicGood { + if noPublicGood { return nil, fmt.Errorf("detected public good instance but requested verification without public good instance") } - return newPublicGoodVerifierWithTrustedRoot(trustedRoot) + if _, ok := verifiers[PublicGoodIssuerOrg]; ok { + // we have already created a public good verifier with this custom trusted root + // so we skip it + continue + } + publicGood, err := newPublicGoodVerifierWithTrustedRoot(trustedRoot) + if err != nil { + return nil, err + } + verifiers[PublicGoodIssuerOrg] = publicGood case GitHubIssuerOrg: - return newGitHubVerifierWithTrustedRoot(trustedRoot) + if _, ok := verifiers[GitHubIssuerOrg]; ok { + // we have already created a github verifier with this custom trusted root + // so we skip it + continue + } + github, err := newGitHubVerifierWithTrustedRoot(trustedRoot) + if err != nil { + return nil, err + } + verifiers[GitHubIssuerOrg] = github default: + if _, ok := verifiers[issuer]; ok { + // we have already created a custom verifier with this custom trusted root + // so we skip it + continue + } // Make best guess at reasonable policy - return newCustomVerifier(trustedRoot) + custom, err := newCustomVerifier(trustedRoot) + if err != nil { + return nil, err + } + verifiers[issuer] = custom } } line, readError = reader.ReadBytes('\n') } + return verifiers, nil +} + +func getBundleIssuer(b *bundle.Bundle) (string, error) { + if !b.MinVersion("0.2") { + return "", fmt.Errorf("unsupported bundle version: %s", b.MediaType) + } + verifyContent, err := b.VerificationContent() + if err != nil { + return "", fmt.Errorf("failed to get bundle verification content: %v", err) + } + leafCert := verifyContent.Certificate() + if leafCert == nil { + return "", fmt.Errorf("leaf cert not found") + } + if len(leafCert.Issuer.Organization) != 1 { + return "", fmt.Errorf("expected the leaf certificate issuer to only have one organization") + } + return leafCert.Issuer.Organization[0], nil +} - return nil, fmt.Errorf("unable to use provided trusted roots") +func (v *LiveSigstoreVerifier) chooseVerifier(issuer string) (*verify.SignedEntityVerifier, error) { + // if no custom trusted root is set, return either the Public Good or GitHub verifier + // If the chosen verifier has not yet been created, create it as a LiveSigstoreVerifier field for use in future calls + if v.Custom != nil { + custom, ok := v.Custom[issuer] + if !ok { + return nil, fmt.Errorf("no custom verifier found for issuer \"%s\"", issuer) + } + return custom, nil + } + switch issuer { + case PublicGoodIssuerOrg: + if v.NoPublicGood { + return nil, fmt.Errorf("detected public good instance but requested verification without public good instance") + } + return v.PublicGood, nil + case GitHubIssuerOrg: + return v.GitHub, nil + default: + return nil, fmt.Errorf("leaf certificate issuer is not recognized") + } } func getLowestCertInChain(ca *root.FulcioCertificateAuthority) (*x509.Certificate, error) { @@ -177,7 +231,7 @@ func (v *LiveSigstoreVerifier) verify(attestation *api.Attestation, policy verif // determine which verifier should attempt verification against the bundle verifier, err := v.chooseVerifier(issuer) if err != nil { - return nil, fmt.Errorf("failed to find recognized issuer from bundle content: %v", err) + return nil, fmt.Errorf("failed to choose verifier based on provided bundle issuer: %v", err) } v.Logger.VerbosePrintf("Attempting verification against issuer \"%s\"\n", issuer) diff --git a/pkg/cmd/attestation/verification/sigstore_integration_test.go b/pkg/cmd/attestation/verification/sigstore_integration_test.go index 987fb9caa07..2a2d3beeac6 100644 --- a/pkg/cmd/attestation/verification/sigstore_integration_test.go +++ b/pkg/cmd/attestation/verification/sigstore_integration_test.go @@ -50,10 +50,11 @@ func TestLiveSigstoreVerifier(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - verifier := NewLiveSigstoreVerifier(SigstoreConfig{ + verifier, err := NewLiveSigstoreVerifier(SigstoreConfig{ Logger: io.NewTestHandler(), TUFMetadataDir: o.Some(t.TempDir()), }) + require.NoError(t, err) results, err := verifier.Verify(tc.attestations, publicGoodPolicy(t)) @@ -69,10 +70,11 @@ func TestLiveSigstoreVerifier(t *testing.T) { } t.Run("with 2/3 verified attestations", func(t *testing.T) { - verifier := NewLiveSigstoreVerifier(SigstoreConfig{ + verifier, err := NewLiveSigstoreVerifier(SigstoreConfig{ Logger: io.NewTestHandler(), TUFMetadataDir: o.Some(t.TempDir()), }) + require.NoError(t, err) invalidBundle := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") @@ -86,10 +88,11 @@ func TestLiveSigstoreVerifier(t *testing.T) { }) t.Run("fail with 0/2 verified attestations", func(t *testing.T) { - verifier := NewLiveSigstoreVerifier(SigstoreConfig{ + verifier, err := NewLiveSigstoreVerifier(SigstoreConfig{ Logger: io.NewTestHandler(), TUFMetadataDir: o.Some(t.TempDir()), }) + require.NoError(t, err) invalidBundle := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0-bundle-v0.1.json") attestations := getAttestationsFor(t, "../test/data/sigstoreBundle-invalid-signature.json") @@ -110,10 +113,11 @@ func TestLiveSigstoreVerifier(t *testing.T) { attestations := getAttestationsFor(t, "../test/data/github_provenance_demo-0.0.12-py3-none-any-bundle.jsonl") - verifier := NewLiveSigstoreVerifier(SigstoreConfig{ + verifier, err := NewLiveSigstoreVerifier(SigstoreConfig{ Logger: io.NewTestHandler(), TUFMetadataDir: o.Some(t.TempDir()), }) + require.NoError(t, err) results, err := verifier.Verify(attestations, githubPolicy) require.Len(t, results, 1) @@ -123,11 +127,12 @@ func TestLiveSigstoreVerifier(t *testing.T) { t.Run("with custom trusted root", func(t *testing.T) { attestations := getAttestationsFor(t, "../test/data/sigstore-js-2.1.0_with_2_bundles.jsonl") - verifier := NewLiveSigstoreVerifier(SigstoreConfig{ + verifier, err := NewLiveSigstoreVerifier(SigstoreConfig{ Logger: io.NewTestHandler(), TrustedRoot: test.NormalizeRelativePath("../test/data/trusted_root.json"), TUFMetadataDir: o.Some(t.TempDir()), }) + require.NoError(t, err) results, err := verifier.Verify(attestations, publicGoodPolicy(t)) require.Len(t, results, 2) diff --git a/pkg/cmd/attestation/verify/attestation_integration_test.go b/pkg/cmd/attestation/verify/attestation_integration_test.go index 9ff1741413e..73452c4255b 100644 --- a/pkg/cmd/attestation/verify/attestation_integration_test.go +++ b/pkg/cmd/attestation/verify/attestation_integration_test.go @@ -25,10 +25,11 @@ func getAttestationsFor(t *testing.T, bundlePath string) []*api.Attestation { } func TestVerifyAttestations(t *testing.T) { - sgVerifier := verification.NewLiveSigstoreVerifier(verification.SigstoreConfig{ + sgVerifier, err := verification.NewLiveSigstoreVerifier(verification.SigstoreConfig{ Logger: io.NewTestHandler(), TUFMetadataDir: o.Some(t.TempDir()), }) + require.NoError(t, err) certSummary := certificate.Summary{} certSummary.SourceRepositoryOwnerURI = "https://github.com/sigstore" diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index 3affdfabb23..b3bad519aad 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -211,7 +211,11 @@ func NewVerifyCmd(f *cmdutil.Factory, runF func(*Options) error) *cobra.Command return runF(opts) } - opts.SigstoreVerifier = verification.NewLiveSigstoreVerifier(config) + sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(config) + if err != nil { + return fmt.Errorf("error creating Sigstore verifier: %w", err) + } + opts.SigstoreVerifier = sigstoreVerifier opts.Config = f.Config if err := runVerify(opts); err != nil { diff --git a/pkg/cmd/attestation/verify/verify_integration_test.go b/pkg/cmd/attestation/verify/verify_integration_test.go index 09479995c1e..92864f78e64 100644 --- a/pkg/cmd/attestation/verify/verify_integration_test.go +++ b/pkg/cmd/attestation/verify/verify_integration_test.go @@ -33,6 +33,8 @@ func TestVerifyIntegration(t *testing.T) { host, _ := auth.DefaultHost() + sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(sigstoreConfig) + require.NoError(t, err) publicGoodOpts := Options{ APIClient: api.NewLiveClient(hc, host, logger), ArtifactPath: artifactPath, @@ -44,7 +46,7 @@ func TestVerifyIntegration(t *testing.T) { Owner: "sigstore", PredicateType: verification.SLSAPredicateV1, SANRegex: "^https://github.com/sigstore/", - SigstoreVerifier: verification.NewLiveSigstoreVerifier(sigstoreConfig), + SigstoreVerifier: sigstoreVerifier, } t.Run("with valid owner", func(t *testing.T) { @@ -106,6 +108,8 @@ func TestVerifyIntegration(t *testing.T) { }) t.Run("with bundle from OCI registry", func(t *testing.T) { + sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(sigstoreConfig) + require.NoError(t, err) opts := Options{ APIClient: api.NewLiveClient(hc, host, logger), ArtifactPath: "oci://ghcr.io/github/artifact-attestations-helm-charts/policy-controller:v0.10.0-github9", @@ -117,10 +121,10 @@ func TestVerifyIntegration(t *testing.T) { Owner: "github", PredicateType: verification.SLSAPredicateV1, SANRegex: "^https://github.com/github/", - SigstoreVerifier: verification.NewLiveSigstoreVerifier(sigstoreConfig), + SigstoreVerifier: sigstoreVerifier, } - err := runVerify(&opts) + err = runVerify(&opts) require.NoError(t, err) }) } @@ -145,6 +149,8 @@ func TestVerifyIntegrationCustomIssuer(t *testing.T) { host, _ := auth.DefaultHost() + sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(sigstoreConfig) + require.NoError(t, err) baseOpts := Options{ APIClient: api.NewLiveClient(hc, host, logger), ArtifactPath: artifactPath, @@ -154,7 +160,7 @@ func TestVerifyIntegrationCustomIssuer(t *testing.T) { OCIClient: oci.NewLiveClient(), OIDCIssuer: "https://token.actions.githubusercontent.com/hammer-time", PredicateType: verification.SLSAPredicateV1, - SigstoreVerifier: verification.NewLiveSigstoreVerifier(sigstoreConfig), + SigstoreVerifier: sigstoreVerifier, } t.Run("with owner and valid workflow SAN", func(t *testing.T) { @@ -216,6 +222,8 @@ func TestVerifyIntegrationReusableWorkflow(t *testing.T) { host, _ := auth.DefaultHost() + sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(sigstoreConfig) + require.NoError(t, err) baseOpts := Options{ APIClient: api.NewLiveClient(hc, host, logger), ArtifactPath: artifactPath, @@ -225,7 +233,7 @@ func TestVerifyIntegrationReusableWorkflow(t *testing.T) { OCIClient: oci.NewLiveClient(), OIDCIssuer: verification.GitHubOIDCIssuer, PredicateType: verification.SLSAPredicateV1, - SigstoreVerifier: verification.NewLiveSigstoreVerifier(sigstoreConfig), + SigstoreVerifier: sigstoreVerifier, } t.Run("with owner and valid reusable workflow SAN", func(t *testing.T) { @@ -306,6 +314,8 @@ func TestVerifyIntegrationReusableWorkflowSignerWorkflow(t *testing.T) { host, _ := auth.DefaultHost() + sigstoreVerifier, err := verification.NewLiveSigstoreVerifier(sigstoreConfig) + require.NoError(t, err) baseOpts := Options{ APIClient: api.NewLiveClient(hc, host, logger), ArtifactPath: artifactPath, @@ -318,7 +328,7 @@ func TestVerifyIntegrationReusableWorkflowSignerWorkflow(t *testing.T) { Owner: "malancas", PredicateType: verification.SLSAPredicateV1, Repo: "malancas/attest-demo", - SigstoreVerifier: verification.NewLiveSigstoreVerifier(sigstoreConfig), + SigstoreVerifier: sigstoreVerifier, } type testcase struct { diff --git a/pkg/cmd/config/list/list_test.go b/pkg/cmd/config/list/list_test.go index 2184d0f1659..27260e85783 100644 --- a/pkg/cmd/config/list/list_test.go +++ b/pkg/cmd/config/list/list_test.go @@ -101,6 +101,9 @@ func Test_listRun(t *testing.T) { http_unix_socket= browser=brave color_labels=disabled + accessible_colors=disabled + accessible_prompter=disabled + spinner=enabled `), }, } diff --git a/pkg/cmd/factory/default.go b/pkg/cmd/factory/default.go index 6286c999d0a..52837b25204 100644 --- a/pkg/cmd/factory/default.go +++ b/pkg/cmd/factory/default.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "regexp" + "slices" "time" "github.com/cli/cli/v2/api" @@ -226,7 +227,7 @@ func newBrowser(f *cmdutil.Factory) browser.Browser { func newPrompter(f *cmdutil.Factory) prompter.Prompter { editor, _ := cmdutil.DetermineEditor(f.Config) io := f.IOStreams - return prompter.New(editor, io.In, io.Out, io.ErrOut) + return prompter.New(editor, io) } func configFunc() func() (gh.Config, error) { @@ -283,6 +284,26 @@ func ioStreams(f *cmdutil.Factory) *iostreams.IOStreams { io.SetNeverPrompt(true) } + falseyValues := []string{"false", "0", "no", ""} + + accessiblePrompterValue, accessiblePrompterIsSet := os.LookupEnv("GH_ACCESSIBLE_PROMPTER") + if accessiblePrompterIsSet { + if !slices.Contains(falseyValues, accessiblePrompterValue) { + io.SetAccessiblePrompterEnabled(true) + } + } else if prompt := cfg.AccessiblePrompter(""); prompt.Value == "enabled" { + io.SetAccessiblePrompterEnabled(true) + } + + ghSpinnerDisabledValue, ghSpinnerDisabledIsSet := os.LookupEnv("GH_SPINNER_DISABLED") + if ghSpinnerDisabledIsSet { + if !slices.Contains(falseyValues, ghSpinnerDisabledValue) { + io.SetSpinnerDisabled(true) + } + } else if spinnerDisabled := cfg.Spinner(""); spinnerDisabled.Value == "disabled" { + io.SetSpinnerDisabled(true) + } + // Pager precedence // 1. GH_PAGER // 2. pager from config diff --git a/pkg/cmd/factory/default_test.go b/pkg/cmd/factory/default_test.go index c0275d1dec4..d7bfe39fd8d 100644 --- a/pkg/cmd/factory/default_test.go +++ b/pkg/cmd/factory/default_test.go @@ -432,6 +432,152 @@ func Test_ioStreams_prompt(t *testing.T) { } } +func Test_ioStreams_spinnerDisabled(t *testing.T) { + tests := []struct { + name string + config gh.Config + spinnerDisabled bool + env map[string]string + }{ + { + name: "default config", + spinnerDisabled: false, + }, + { + name: "config with spinner disabled", + config: disableSpinnersConfig(), + spinnerDisabled: true, + }, + { + name: "config with spinner enabled", + config: enableSpinnersConfig(), + spinnerDisabled: false, + }, + { + name: "spinner disabled via GH_SPINNER_DISABLED env var = 0", + env: map[string]string{"GH_SPINNER_DISABLED": "0"}, + spinnerDisabled: false, + }, + { + name: "spinner disabled via GH_SPINNER_DISABLED env var = false", + env: map[string]string{"GH_SPINNER_DISABLED": "false"}, + spinnerDisabled: false, + }, + { + name: "spinner disabled via GH_SPINNER_DISABLED env var = no", + env: map[string]string{"GH_SPINNER_DISABLED": "no"}, + spinnerDisabled: false, + }, + { + name: "spinner enabled via GH_SPINNER_DISABLED env var = 1", + env: map[string]string{"GH_SPINNER_DISABLED": "1"}, + spinnerDisabled: true, + }, + { + name: "spinner enabled via GH_SPINNER_DISABLED env var = true", + env: map[string]string{"GH_SPINNER_DISABLED": "true"}, + spinnerDisabled: true, + }, + { + name: "config enabled but env disabled, respects env", + config: enableSpinnersConfig(), + env: map[string]string{"GH_SPINNER_DISABLED": "true"}, + spinnerDisabled: true, + }, + { + name: "config disabled but env enabled, respects env", + config: disableSpinnersConfig(), + env: map[string]string{"GH_SPINNER_DISABLED": "false"}, + spinnerDisabled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + f := New("1") + f.Config = func() (gh.Config, error) { + if tt.config == nil { + return config.NewBlankConfig(), nil + } else { + return tt.config, nil + } + } + io := ioStreams(f) + assert.Equal(t, tt.spinnerDisabled, io.GetSpinnerDisabled()) + }) + } +} + +func Test_ioStreams_accessiblePrompterEnabled(t *testing.T) { + tests := []struct { + name string + config gh.Config + accessiblePrompterEnabled bool + env map[string]string + }{ + { + name: "default config", + accessiblePrompterEnabled: false, + }, + { + name: "config with accessible prompter enabled", + config: enableAccessiblePrompterConfig(), + accessiblePrompterEnabled: true, + }, + { + name: "config with accessible prompter disabled", + config: disableAccessiblePrompterConfig(), + accessiblePrompterEnabled: false, + }, + { + name: "accessible prompter enabled via GH_ACCESSIBLE_PROMPTER env var = 1", + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "1"}, + accessiblePrompterEnabled: true, + }, + { + name: "accessible prompter enabled via GH_ACCESSIBLE_PROMPTER env var = true", + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "true"}, + accessiblePrompterEnabled: true, + }, + { + name: "accessible prompter disabled via GH_ACCESSIBLE_PROMPTER env var = 0", + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "0"}, + accessiblePrompterEnabled: false, + }, + { + name: "config disabled but env enabled, respects env", + config: disableAccessiblePrompterConfig(), + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "true"}, + accessiblePrompterEnabled: true, + }, + { + name: "config enabled but env disabled, respects env", + config: enableAccessiblePrompterConfig(), + env: map[string]string{"GH_ACCESSIBLE_PROMPTER": "false"}, + accessiblePrompterEnabled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.env { + t.Setenv(k, v) + } + f := New("1") + f.Config = func() (gh.Config, error) { + if tt.config == nil { + return config.NewBlankConfig(), nil + } else { + return tt.config, nil + } + } + io := ioStreams(f) + assert.Equal(t, tt.accessiblePrompterEnabled, io.AccessiblePrompterEnabled()) + }) + } +} + func Test_ioStreams_colorLabels(t *testing.T) { tests := []struct { name string @@ -616,6 +762,22 @@ func disablePromptConfig() gh.Config { return config.NewFromString("prompt: disabled") } +func enableAccessiblePrompterConfig() gh.Config { + return config.NewFromString("accessible_prompter: enabled") +} + +func disableAccessiblePrompterConfig() gh.Config { + return config.NewFromString("accessible_prompter: disabled") +} + +func disableSpinnersConfig() gh.Config { + return config.NewFromString("spinner: disabled") +} + +func enableSpinnersConfig() gh.Config { + return config.NewFromString("spinner: enabled") +} + func disableColorLabelsConfig() gh.Config { return config.NewFromString("color_labels: disabled") } diff --git a/pkg/cmd/issue/argparsetest/argparsetest.go b/pkg/cmd/issue/argparsetest/argparsetest.go new file mode 100644 index 00000000000..5ae1ada8de4 --- /dev/null +++ b/pkg/cmd/issue/argparsetest/argparsetest.go @@ -0,0 +1,137 @@ +package argparsetest + +import ( + "bytes" + "reflect" + "testing" + + "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/google/shlex" + "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newCmdFunc represents the typical function signature we use for creating commands e.g. `NewCmdView`. +// +// It is generic over `T` as each command construction has their own Options type e.g. `ViewOptions` +type newCmdFunc[T any] func(f *cmdutil.Factory, runF func(*T) error) *cobra.Command + +// TestArgParsing is a test helper that verifies that issue commands correctly parse the `{issue number | url}` +// positional arg into an issue number and base repo. +// +// Looking through the existing tests, I noticed that the coverage was pretty smattered. +// Since nearly all issue commands only accept a single positional argument, we are able to reuse this test helper. +// Commands with no further flags or args can use this solely. +// Commands with extra flags use this and further table tests. +// Commands with extra required positional arguments (like `transfer`) cannot use this. They duplicate these cases inline. +func TestArgParsing[T any](t *testing.T, fn newCmdFunc[T]) { + tests := []struct { + name string + input string + expectedissueNumber int + expectedBaseRepo ghrepo.Interface + expectErr bool + }{ + { + name: "no argument", + input: "", + expectErr: true, + }, + { + name: "issue number argument", + input: "23 --repo owner/repo", + expectedissueNumber: 23, + expectedBaseRepo: ghrepo.New("owner", "repo"), + }, + { + name: "argument is hash prefixed number", + // Escaping is required here to avoid what I think is shellex treating it as a comment. + input: "\\#23 --repo owner/repo", + expectedissueNumber: 23, + expectedBaseRepo: ghrepo.New("owner", "repo"), + }, + { + name: "argument is a URL", + input: "https://github.com/cli/cli/issues/23", + expectedissueNumber: 23, + expectedBaseRepo: ghrepo.New("cli", "cli"), + }, + { + name: "argument cannot be parsed to an issue", + input: "unparseable", + expectErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &cmdutil.Factory{} + + argv, err := shlex.Split(tt.input) + assert.NoError(t, err) + + var gotOpts T + cmd := fn(f, func(opts *T) error { + gotOpts = *opts + return nil + }) + + cmdutil.EnableRepoOverride(cmd, f) + + // TODO: remember why we do this + cmd.Flags().BoolP("help", "x", false, "") + + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + + _, err = cmd.ExecuteC() + + if tt.expectErr { + require.Error(t, err) + return + } else { + require.NoError(t, err) + } + + actualIssueNumber := issueNumberFromOpts(t, gotOpts) + assert.Equal(t, tt.expectedissueNumber, actualIssueNumber) + + actualBaseRepo := baseRepoFromOpts(t, gotOpts) + assert.True( + t, + ghrepo.IsSame(tt.expectedBaseRepo, actualBaseRepo), + "expected base repo %+v, got %+v", tt.expectedBaseRepo, actualBaseRepo, + ) + }) + } +} + +func issueNumberFromOpts(t *testing.T, v any) int { + rv := reflect.ValueOf(v) + field := rv.FieldByName("IssueNumber") + if !field.IsValid() || field.Kind() != reflect.Int { + t.Fatalf("Type %T does not have IssueNumber int field", v) + } + return int(field.Int()) +} + +func baseRepoFromOpts(t *testing.T, v any) ghrepo.Interface { + rv := reflect.ValueOf(v) + field := rv.FieldByName("BaseRepo") + // check whether the field is valid and of type func() (ghrepo.Interface, error) + if !field.IsValid() || field.Kind() != reflect.Func { + t.Fatalf("Type %T does not have BaseRepo func field", v) + } + // call the function and check the return value + results := field.Call([]reflect.Value{}) + if len(results) != 2 { + t.Fatalf("%T.BaseRepo() does not return two values", v) + } + if !results[1].IsNil() { + t.Fatalf("%T.BaseRepo() returned an error: %v", v, results[1].Interface()) + } + return results[0].Interface().(ghrepo.Interface) +} diff --git a/pkg/cmd/issue/close/close.go b/pkg/cmd/issue/close/close.go index 9197abff661..21fe45dd666 100644 --- a/pkg/cmd/issue/close/close.go +++ b/pkg/cmd/issue/close/close.go @@ -21,7 +21,7 @@ type CloseOptions struct { IO *iostreams.IOStreams BaseRepo func() (ghrepo.Interface, error) - SelectorArg string + IssueNumber int Comment string Reason string @@ -39,13 +39,23 @@ func NewCmdClose(f *cmdutil.Factory, runF func(*CloseOptions) error) *cobra.Comm Short: "Close issue", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } - if len(args) > 0 { - opts.SelectorArg = args[0] + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo } + opts.IssueNumber = issueNumber + if runF != nil { return runF(opts) } @@ -67,7 +77,12 @@ func closeRun(opts *CloseOptions) error { return err } - issue, baseRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.SelectorArg, []string{"id", "number", "title", "state"}) + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number", "title", "state"}) if err != nil { return err } diff --git a/pkg/cmd/issue/close/close_test.go b/pkg/cmd/issue/close/close_test.go index 4d50e56b2e0..04c39cd8da0 100644 --- a/pkg/cmd/issue/close/close_test.go +++ b/pkg/cmd/issue/close/close_test.go @@ -7,46 +7,32 @@ import ( fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCmdClose(t *testing.T) { + // Test shared parsing of issue number / URL. + argparsetest.TestArgParsing(t, NewCmdClose) + tests := []struct { - name string - input string - output CloseOptions - wantErr bool - errMsg string + name string + input string + output CloseOptions + expectedBaseRepo ghrepo.Interface + wantErr bool + errMsg string }{ - { - name: "no argument", - input: "", - wantErr: true, - errMsg: "accepts 1 arg(s), received 0", - }, - { - name: "issue number", - input: "123", - output: CloseOptions{ - SelectorArg: "123", - }, - }, - { - name: "issue url", - input: "https://github.com/cli/cli/3", - output: CloseOptions{ - SelectorArg: "https://github.com/cli/cli/3", - }, - }, { name: "comment", input: "123 --comment 'closing comment'", output: CloseOptions{ - SelectorArg: "123", + IssueNumber: 123, Comment: "closing comment", }, }, @@ -54,7 +40,7 @@ func TestNewCmdClose(t *testing.T) { name: "reason", input: "123 --reason 'not planned'", output: CloseOptions{ - SelectorArg: "123", + IssueNumber: 123, Reason: "not planned", }, }, @@ -79,15 +65,24 @@ func TestNewCmdClose(t *testing.T) { _, err = cmd.ExecuteC() if tt.wantErr { - assert.Error(t, err) + require.Error(t, err) assert.Equal(t, tt.errMsg, err.Error()) return } - assert.NoError(t, err) - assert.Equal(t, tt.output.SelectorArg, gotOpts.SelectorArg) + require.NoError(t, err) + assert.Equal(t, tt.output.IssueNumber, gotOpts.IssueNumber) assert.Equal(t, tt.output.Comment, gotOpts.Comment) assert.Equal(t, tt.output.Reason, gotOpts.Reason) + if tt.expectedBaseRepo != nil { + baseRepo, err := gotOpts.BaseRepo() + require.NoError(t, err) + require.True( + t, + ghrepo.IsSame(tt.expectedBaseRepo, baseRepo), + "expected base repo %+v, got %+v", tt.expectedBaseRepo, baseRepo, + ) + } }) } } @@ -104,7 +99,7 @@ func TestCloseRun(t *testing.T) { { name: "close issue by number", opts: &CloseOptions{ - SelectorArg: "13", + IssueNumber: 13, }, httpStubs: func(reg *httpmock.Registry) { reg.Register( @@ -128,7 +123,7 @@ func TestCloseRun(t *testing.T) { { name: "close issue with comment", opts: &CloseOptions{ - SelectorArg: "13", + IssueNumber: 13, Comment: "closing comment", }, httpStubs: func(reg *httpmock.Registry) { @@ -164,7 +159,7 @@ func TestCloseRun(t *testing.T) { { name: "close issue with reason", opts: &CloseOptions{ - SelectorArg: "13", + IssueNumber: 13, Reason: "not planned", Detector: &fd.EnabledDetectorMock{}, }, @@ -192,7 +187,7 @@ func TestCloseRun(t *testing.T) { { name: "close issue with reason when reason is not supported", opts: &CloseOptions{ - SelectorArg: "13", + IssueNumber: 13, Reason: "not planned", Detector: &fd.DisabledDetectorMock{}, }, @@ -219,7 +214,7 @@ func TestCloseRun(t *testing.T) { { name: "issue already closed", opts: &CloseOptions{ - SelectorArg: "13", + IssueNumber: 13, }, httpStubs: func(reg *httpmock.Registry) { reg.Register( @@ -236,7 +231,7 @@ func TestCloseRun(t *testing.T) { { name: "issues disabled", opts: &CloseOptions{ - SelectorArg: "13", + IssueNumber: 13, }, httpStubs: func(reg *httpmock.Registry) { reg.Register( diff --git a/pkg/cmd/issue/comment/comment.go b/pkg/cmd/issue/comment/comment.go index 090b0748c4b..706ff791eae 100644 --- a/pkg/cmd/issue/comment/comment.go +++ b/pkg/cmd/issue/comment/comment.go @@ -3,6 +3,7 @@ package comment import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmd/issue/shared" issueShared "github.com/cli/cli/v2/pkg/cmd/issue/shared" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -37,15 +38,41 @@ func NewCmdComment(f *cmdutil.Factory, runF func(*prShared.CommentableOptions) e Args: cobra.ExactArgs(1), PreRunE: func(cmd *cobra.Command, args []string) error { opts.RetrieveCommentable = func() (prShared.Commentable, ghrepo.Interface, error) { + // TODO wm: more testing + issueNumber, parsedBaseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return nil, nil, err + } + + // If the args provided the base repo then use that directly. + var baseRepo ghrepo.Interface + + if parsedBaseRepo, present := parsedBaseRepo.Value(); present { + baseRepo = parsedBaseRepo + } else { + // support `-R, --repo` override + baseRepo, err = f.BaseRepo() + if err != nil { + return nil, nil, err + } + } + httpClient, err := f.HttpClient() if err != nil { return nil, nil, err } + fields := []string{"id", "url"} if opts.EditLast { fields = append(fields, "comments") } - return issueShared.IssueFromArgWithFields(httpClient, f.BaseRepo, args[0], fields) + + issue, err := issueShared.FindIssueOrPR(httpClient, baseRepo, issueNumber, fields) + if err != nil { + return nil, nil, err + } + + return issue, baseRepo, nil } return prShared.CommentablePreRun(cmd, opts) }, diff --git a/pkg/cmd/issue/create/create.go b/pkg/cmd/issue/create/create.go index 2e3e0de519a..2978a21fc7b 100644 --- a/pkg/cmd/issue/create/create.go +++ b/pkg/cmd/issue/create/create.go @@ -4,10 +4,12 @@ import ( "errors" "fmt" "net/http" + "time" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" @@ -24,6 +26,7 @@ type CreateOptions struct { BaseRepo func() (ghrepo.Interface, error) Browser browser.Browser Prompter prShared.Prompt + Detector fd.Detector TitledEditSurvey func(string, string) (string, string, error) RootDirOverride string @@ -46,11 +49,12 @@ type CreateOptions struct { func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Command { opts := &CreateOptions{ - IO: f.IOStreams, - HttpClient: f.HttpClient, - Config: f.Config, - Browser: f.Browser, - Prompter: f.Prompter, + IO: f.IOStreams, + HttpClient: f.HttpClient, + Config: f.Config, + Browser: f.Browser, + Prompter: f.Prompter, + TitledEditSurvey: prShared.TitledEditSurvey(&prShared.UserEditor{Config: f.Config, IO: f.IOStreams}), } @@ -146,6 +150,15 @@ func createRun(opts *CreateOptions) (err error) { return } + // TODO projectsV1Deprecation + // Remove this section as we should no longer need to detect + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) + } + + projectsV1Support := opts.Detector.ProjectsV1() + isTerminal := opts.IO.IsStdoutTTY() var milestones []string @@ -160,13 +173,13 @@ func createRun(opts *CreateOptions) (err error) { } tb := prShared.IssueMetadataState{ - Type: prShared.IssueMetadata, - Assignees: assignees, - Labels: opts.Labels, - Projects: opts.Projects, - Milestones: milestones, - Title: opts.Title, - Body: opts.Body, + Type: prShared.IssueMetadata, + Assignees: assignees, + Labels: opts.Labels, + ProjectTitles: opts.Projects, + Milestones: milestones, + Title: opts.Title, + Body: opts.Body, } if opts.RecoverFile != "" { @@ -182,7 +195,7 @@ func createRun(opts *CreateOptions) (err error) { if opts.WebMode { var openURL string if opts.Title != "" || opts.Body != "" || tb.HasMetadata() { - openURL, err = generatePreviewURL(apiClient, baseRepo, tb) + openURL, err = generatePreviewURL(apiClient, baseRepo, tb, projectsV1Support) if err != nil { return } @@ -260,7 +273,7 @@ func createRun(opts *CreateOptions) (err error) { } } - openURL, err = generatePreviewURL(apiClient, baseRepo, tb) + openURL, err = generatePreviewURL(apiClient, baseRepo, tb, projectsV1Support) if err != nil { return } @@ -279,7 +292,7 @@ func createRun(opts *CreateOptions) (err error) { Repo: baseRepo, State: &tb, } - err = prShared.MetadataSurvey(opts.Prompter, opts.IO, baseRepo, fetcher, &tb) + err = prShared.MetadataSurvey(opts.Prompter, opts.IO, baseRepo, fetcher, &tb, projectsV1Support) if err != nil { return } @@ -335,7 +348,7 @@ func createRun(opts *CreateOptions) (err error) { params["issueTemplate"] = templateNameForSubmit } - err = prShared.AddMetadataToIssueParams(apiClient, baseRepo, params, &tb) + err = prShared.AddMetadataToIssueParams(apiClient, baseRepo, params, &tb, projectsV1Support) if err != nil { return } @@ -354,7 +367,7 @@ func createRun(opts *CreateOptions) (err error) { return } -func generatePreviewURL(apiClient *api.Client, baseRepo ghrepo.Interface, tb prShared.IssueMetadataState) (string, error) { +func generatePreviewURL(apiClient *api.Client, baseRepo ghrepo.Interface, tb prShared.IssueMetadataState, projectsV1Support gh.ProjectsV1Support) (string, error) { openURL := ghrepo.GenerateRepoURL(baseRepo, "issues/new") - return prShared.WithPrAndIssueQueryParams(apiClient, baseRepo, openURL, tb) + return prShared.WithPrAndIssueQueryParams(apiClient, baseRepo, openURL, tb, projectsV1Support) } diff --git a/pkg/cmd/issue/create/create_test.go b/pkg/cmd/issue/create/create_test.go index 8e49700a012..1211c0c1d93 100644 --- a/pkg/cmd/issue/create/create_test.go +++ b/pkg/cmd/issue/create/create_test.go @@ -14,6 +14,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/config" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" @@ -473,6 +474,7 @@ func Test_createRun(t *testing.T) { opts.BaseRepo = func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil } + opts.Detector = &fd.EnabledDetectorMock{} browser := &browser.Stub{} opts.Browser = browser @@ -521,6 +523,7 @@ func runCommandWithRootDirOverridden(rt http.RoundTripper, isTTY bool, cli strin cmd := NewCmdCreate(factory, func(opts *CreateOptions) error { opts.RootDirOverride = rootDir + opts.Detector = &fd.EnabledDetectorMock{} return createRun(opts) }) @@ -1026,3 +1029,146 @@ func TestIssueCreate_projectsV2(t *testing.T) { assert.Equal(t, "https://github.com/OWNER/REPO/issues/12\n", output.String()) } + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + + t.Run("non-interactive submission", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&CreateOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + + Detector: &fd.EnabledDetectorMock{}, + Title: "Test Title", + Body: "Test Body", + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "main") + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we're not really interested in it. + _ = createRun(&CreateOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + + Detector: &fd.DisabledDetectorMock{}, + Title: "Test Title", + Body: "Test Body", + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + }) + + t.Run("web mode", func(t *testing.T) { + t.Run("when projects v1 is supported, queries for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + // ( is required to avoid matching projectsV2 + httpmock.GraphQL(`projects\(`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = createRun(&CreateOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + + Detector: &fd.EnabledDetectorMock{}, + WebMode: true, + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projects + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, does not query for it", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + // ( is required to avoid matching projectsV2 + reg.Exclude(t, httpmock.GraphQL(`projects\(`)) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we're not really interested in it. + _ = createRun(&CreateOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + + Detector: &fd.DisabledDetectorMock{}, + WebMode: true, + // Required to force a lookup of projects + Projects: []string{"Project"}, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + }) +} diff --git a/pkg/cmd/issue/delete/delete.go b/pkg/cmd/issue/delete/delete.go index fb41f288e56..269ef7081a7 100644 --- a/pkg/cmd/issue/delete/delete.go +++ b/pkg/cmd/issue/delete/delete.go @@ -21,7 +21,7 @@ type DeleteOptions struct { BaseRepo func() (ghrepo.Interface, error) Prompter iprompter - SelectorArg string + IssueNumber int Confirmed bool } @@ -42,13 +42,23 @@ func NewCmdDelete(f *cmdutil.Factory, runF func(*DeleteOptions) error) *cobra.Co Short: "Delete issue", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } - if len(args) > 0 { - opts.SelectorArg = args[0] + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo } + opts.IssueNumber = issueNumber + if runF != nil { return runF(opts) } @@ -71,7 +81,12 @@ func deleteRun(opts *DeleteOptions) error { return err } - issue, baseRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.SelectorArg, []string{"id", "number", "title"}) + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number", "title"}) if err != nil { return err } diff --git a/pkg/cmd/issue/delete/delete_test.go b/pkg/cmd/issue/delete/delete_test.go index bd83c826f01..64522b1d3f7 100644 --- a/pkg/cmd/issue/delete/delete_test.go +++ b/pkg/cmd/issue/delete/delete_test.go @@ -12,6 +12,7 @@ import ( "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" @@ -20,6 +21,10 @@ import ( "github.com/stretchr/testify/assert" ) +func TestNewCmdDelete(t *testing.T) { + argparsetest.TestArgParsing(t, NewCmdDelete) +} + func runCommand(rt http.RoundTripper, pm *prompter.MockPrompter, isTTY bool, cli string) (*test.CmdOut, error) { ios, _, stdout, stderr := iostreams.Test() ios.SetStdoutTTY(isTTY) diff --git a/pkg/cmd/issue/develop/develop.go b/pkg/cmd/issue/develop/develop.go index 1536800f072..19c9b5fa903 100644 --- a/pkg/cmd/issue/develop/develop.go +++ b/pkg/cmd/issue/develop/develop.go @@ -24,12 +24,12 @@ type DevelopOptions struct { BaseRepo func() (ghrepo.Interface, error) Remotes func() (context.Remotes, error) - IssueSelector string - Name string - BranchRepo string - BaseBranch string - Checkout bool - List bool + IssueNumber int + Name string + BranchRepo string + BaseBranch string + Checkout bool + List bool } func NewCmdDevelop(f *cmdutil.Factory, runF func(*DevelopOptions) error) *cobra.Command { @@ -89,9 +89,23 @@ func NewCmdDevelop(f *cmdutil.Factory, runF func(*DevelopOptions) error) *cobra. return nil }, RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo - opts.IssueSelector = args[0] + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } + + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + } + + opts.IssueNumber = issueNumber + if err := cmdutil.MutuallyExclusive("specify only one of `--list` or `--branch-repo`", opts.List, opts.BranchRepo != ""); err != nil { return err } @@ -131,8 +145,13 @@ func developRun(opts *DevelopOptions) error { return err } + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + opts.IO.StartProgressIndicator() - issue, issueRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.IssueSelector, []string{"id", "number"}) + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number"}) opts.IO.StopProgressIndicator() if err != nil { return err @@ -141,16 +160,16 @@ func developRun(opts *DevelopOptions) error { apiClient := api.NewClientFromHTTP(httpClient) opts.IO.StartProgressIndicator() - err = api.CheckLinkedBranchFeature(apiClient, issueRepo.RepoHost()) + err = api.CheckLinkedBranchFeature(apiClient, baseRepo.RepoHost()) opts.IO.StopProgressIndicator() if err != nil { return err } if opts.List { - return developRunList(opts, apiClient, issueRepo, issue) + return developRunList(opts, apiClient, baseRepo, issue) } - return developRunCreate(opts, apiClient, issueRepo, issue) + return developRunCreate(opts, apiClient, baseRepo, issue) } func developRunCreate(opts *DevelopOptions, apiClient *api.Client, issueRepo ghrepo.Interface, issue *api.Issue) error { diff --git a/pkg/cmd/issue/develop/develop_test.go b/pkg/cmd/issue/develop/develop_test.go index 831f03fc3bd..2485c8cc4cf 100644 --- a/pkg/cmd/issue/develop/develop_test.go +++ b/pkg/cmd/issue/develop/develop_test.go @@ -11,89 +11,74 @@ import ( "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/run" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCmdDevelop(t *testing.T) { + // Test shared parsing of issue number / URL. + argparsetest.TestArgParsing(t, NewCmdDevelop) + tests := []struct { - name string - input string - output DevelopOptions - wantStdout string - wantStderr string - wantErr bool - errMsg string + name string + input string + output DevelopOptions + expectedBaseRepo ghrepo.Interface + wantStdout string + wantStderr string + wantErr bool + errMsg string }{ - { - name: "no argument", - input: "", - output: DevelopOptions{}, - wantErr: true, - errMsg: "issue number or url is required", - }, - { - name: "issue number", - input: "1", - output: DevelopOptions{ - IssueSelector: "1", - }, - }, - { - name: "issue url", - input: "https://github.com/cli/cli/issues/1", - output: DevelopOptions{ - IssueSelector: "https://github.com/cli/cli/issues/1", - }, - }, { name: "branch-repo flag", input: "1 --branch-repo owner/repo", output: DevelopOptions{ - IssueSelector: "1", - BranchRepo: "owner/repo", + IssueNumber: 1, + BranchRepo: "owner/repo", }, }, { name: "base flag", input: "1 --base feature", output: DevelopOptions{ - IssueSelector: "1", - BaseBranch: "feature", + IssueNumber: 1, + BaseBranch: "feature", }, }, { name: "checkout flag", input: "1 --checkout", output: DevelopOptions{ - IssueSelector: "1", - Checkout: true, + IssueNumber: 1, + Checkout: true, }, }, { name: "list flag", input: "1 --list", output: DevelopOptions{ - IssueSelector: "1", - List: true, + IssueNumber: 1, + List: true, }, }, { name: "name flag", input: "1 --name feature", output: DevelopOptions{ - IssueSelector: "1", - Name: "feature", + IssueNumber: 1, + Name: "feature", }, }, { name: "issue-repo flag", input: "1 --issue-repo cli/cli", output: DevelopOptions{ - IssueSelector: "1", + IssueNumber: 1, }, wantStdout: "Flag --issue-repo has been deprecated, use `--repo` instead\n", }, @@ -143,18 +128,27 @@ func TestNewCmdDevelop(t *testing.T) { _, err = cmd.ExecuteC() if tt.wantErr { - assert.EqualError(t, err, tt.errMsg) + require.EqualError(t, err, tt.errMsg) return } - assert.NoError(t, err) - assert.Equal(t, tt.output.IssueSelector, gotOpts.IssueSelector) + require.NoError(t, err) + assert.Equal(t, tt.output.IssueNumber, gotOpts.IssueNumber) assert.Equal(t, tt.output.Name, gotOpts.Name) assert.Equal(t, tt.output.BaseBranch, gotOpts.BaseBranch) assert.Equal(t, tt.output.Checkout, gotOpts.Checkout) assert.Equal(t, tt.output.List, gotOpts.List) assert.Equal(t, tt.wantStdout, stdOut.String()) assert.Equal(t, tt.wantStderr, stdErr.String()) + if tt.expectedBaseRepo != nil { + baseRepo, err := gotOpts.BaseRepo() + require.NoError(t, err) + require.True( + t, + ghrepo.IsSame(tt.expectedBaseRepo, baseRepo), + "expected base repo %+v, got %+v", tt.expectedBaseRepo, baseRepo, + ) + } }) } } @@ -178,8 +172,8 @@ func TestDevelopRun(t *testing.T) { { name: "returns an error when the feature is not supported by the API", opts: &DevelopOptions{ - IssueSelector: "42", - List: true, + IssueNumber: 42, + List: true, }, httpStubs: func(reg *httpmock.Registry, t *testing.T) { reg.Register( @@ -196,8 +190,8 @@ func TestDevelopRun(t *testing.T) { { name: "list branches for an issue", opts: &DevelopOptions{ - IssueSelector: "42", - List: true, + IssueNumber: 42, + List: true, }, httpStubs: func(reg *httpmock.Registry, t *testing.T) { reg.Register( @@ -223,8 +217,8 @@ func TestDevelopRun(t *testing.T) { { name: "list branches for an issue in tty", opts: &DevelopOptions{ - IssueSelector: "42", - List: true, + IssueNumber: 42, + List: true, }, tty: true, httpStubs: func(reg *httpmock.Registry, t *testing.T) { @@ -255,37 +249,10 @@ func TestDevelopRun(t *testing.T) { bar https://github.com/OWNER/OTHER-REPO/tree/bar `), }, - { - name: "list branches for an issue providing an issue url", - opts: &DevelopOptions{ - IssueSelector: "https://github.com/cli/cli/issues/42", - List: true, - }, - httpStubs: func(reg *httpmock.Registry, t *testing.T) { - reg.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(`{"data":{"repository":{"hasIssuesEnabled":true,"issue":{"id":"SOMEID","number":42}}}}`), - ) - reg.Register( - httpmock.GraphQL(`query LinkedBranchFeature\b`), - httpmock.StringResponse(featureEnabledPayload), - ) - reg.Register( - httpmock.GraphQL(`query ListLinkedBranches\b`), - httpmock.GraphQLQuery(` - {"data":{"repository":{"issue":{"linkedBranches":{"nodes":[{"ref":{"name":"foo","repository":{"url":"https://github.com/OWNER/REPO"}}},{"ref":{"name":"bar","repository":{"url":"https://github.com/OWNER/OTHER-REPO"}}}]}}}}} - `, func(query string, inputs map[string]interface{}) { - assert.Equal(t, float64(42), inputs["number"]) - assert.Equal(t, "cli", inputs["owner"]) - assert.Equal(t, "cli", inputs["name"]) - })) - }, - expectedOut: "foo\thttps://github.com/OWNER/REPO/tree/foo\nbar\thttps://github.com/OWNER/OTHER-REPO/tree/bar\n", - }, { name: "develop new branch", opts: &DevelopOptions{ - IssueSelector: "123", + IssueNumber: 123, }, remotes: map[string]string{ "origin": "OWNER/REPO", @@ -321,8 +288,8 @@ func TestDevelopRun(t *testing.T) { { name: "develop new branch in different repo than issue", opts: &DevelopOptions{ - IssueSelector: "123", - BranchRepo: "OWNER2/REPO", + IssueNumber: 123, + BranchRepo: "OWNER2/REPO", }, remotes: map[string]string{ "origin": "OWNER2/REPO", @@ -367,9 +334,9 @@ func TestDevelopRun(t *testing.T) { { name: "develop new branch with name and base specified", opts: &DevelopOptions{ - Name: "my-branch", - BaseBranch: "main", - IssueSelector: "123", + Name: "my-branch", + BaseBranch: "main", + IssueNumber: 123, }, remotes: map[string]string{ "origin": "OWNER/REPO", @@ -406,7 +373,10 @@ func TestDevelopRun(t *testing.T) { { name: "develop new branch outside of local git repo", opts: &DevelopOptions{ - IssueSelector: "https://github.com/cli/cli/issues/123", + IssueNumber: 123, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("cli", "cli"), nil + }, }, httpStubs: func(reg *httpmock.Registry, t *testing.T) { reg.Register( @@ -436,9 +406,9 @@ func TestDevelopRun(t *testing.T) { { name: "develop new branch with checkout when local branch exists", opts: &DevelopOptions{ - Name: "my-branch", - IssueSelector: "123", - Checkout: true, + Name: "my-branch", + IssueNumber: 123, + Checkout: true, }, remotes: map[string]string{ "origin": "OWNER/REPO", @@ -478,9 +448,9 @@ func TestDevelopRun(t *testing.T) { { name: "develop new branch with checkout when local branch does not exist", opts: &DevelopOptions{ - Name: "my-branch", - IssueSelector: "123", - Checkout: true, + Name: "my-branch", + IssueNumber: 123, + Checkout: true, }, remotes: map[string]string{ "origin": "OWNER/REPO", @@ -519,8 +489,8 @@ func TestDevelopRun(t *testing.T) { { name: "develop with base branch which does not exist", opts: &DevelopOptions{ - IssueSelector: "123", - BaseBranch: "does-not-exist-branch", + IssueNumber: 123, + BaseBranch: "does-not-exist-branch", }, remotes: map[string]string{ "origin": "OWNER/REPO", @@ -561,8 +531,10 @@ func TestDevelopRun(t *testing.T) { ios.SetStderrTTY(tt.tty) opts.IO = ios - opts.BaseRepo = func() (ghrepo.Interface, error) { - return ghrepo.New("OWNER", "REPO"), nil + if opts.BaseRepo == nil { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + } } opts.Remotes = func() (context.Remotes, error) { diff --git a/pkg/cmd/issue/edit/edit.go b/pkg/cmd/issue/edit/edit.go index 18067319f40..8386cbcfa24 100644 --- a/pkg/cmd/issue/edit/edit.go +++ b/pkg/cmd/issue/edit/edit.go @@ -5,9 +5,12 @@ import ( "net/http" "sort" "sync" + "time" "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" + fd "github.com/cli/cli/v2/internal/featuredetection" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" shared "github.com/cli/cli/v2/pkg/cmd/issue/shared" @@ -22,13 +25,14 @@ type EditOptions struct { IO *iostreams.IOStreams BaseRepo func() (ghrepo.Interface, error) Prompter prShared.EditPrompter + Detector fd.Detector DetermineEditor func() (string, error) FieldsToEditSurvey func(prShared.EditPrompter, *prShared.Editable) error EditFieldsSurvey func(prShared.EditPrompter, *prShared.Editable, string) error FetchOptions func(*api.Client, ghrepo.Interface, *prShared.Editable) error - SelectorArgs []string + IssueNumbers []int Interactive bool prShared.Editable @@ -69,10 +73,22 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman `), Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + issueNumbers, baseRepo, err := shared.ParseIssuesFromArgs(args) + if err != nil { + return err + } - opts.SelectorArgs = args + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + } + + opts.IssueNumbers = issueNumbers flags := cmd.Flags() @@ -134,7 +150,7 @@ func NewCmdEdit(f *cmdutil.Factory, runF func(*EditOptions) error) *cobra.Comman return cmdutil.FlagErrorf("field to edit flag required when not running interactively") } - if opts.Interactive && len(opts.SelectorArgs) > 1 { + if opts.Interactive && len(opts.IssueNumbers) > 1 { return cmdutil.FlagErrorf("multiple issues cannot be edited interactively") } @@ -167,6 +183,11 @@ func editRun(opts *EditOptions) error { return err } + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + // Prompt the user which fields they'd like to edit. editable := opts.Editable if opts.Interactive { @@ -184,7 +205,18 @@ func editRun(opts *EditOptions) error { lookupFields = append(lookupFields, "labels") } if editable.Projects.Edited { - lookupFields = append(lookupFields, "projectCards") + // TODO projectsV1Deprecation + // Remove this section as we should no longer add projectCards + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) + } + + projectsV1Support := opts.Detector.ProjectsV1() + if projectsV1Support == gh.ProjectsV1Supported { + lookupFields = append(lookupFields, "projectCards") + } + lookupFields = append(lookupFields, "projectItems") } if editable.Milestone.Edited { @@ -192,7 +224,7 @@ func editRun(opts *EditOptions) error { } // Get all specified issues and make sure they are within the same repo. - issues, repo, err := shared.IssuesFromArgsWithFields(httpClient, opts.BaseRepo, opts.SelectorArgs, lookupFields) + issues, err := shared.FindIssuesOrPRs(httpClient, baseRepo, opts.IssueNumbers, lookupFields) if err != nil { return err } @@ -200,7 +232,7 @@ func editRun(opts *EditOptions) error { // Fetch editable shared fields once for all issues. apiClient := api.NewClientFromHTTP(httpClient) opts.IO.StartProgressIndicatorWithLabel("Fetching repository information") - err = opts.FetchOptions(apiClient, repo, &editable) + err = opts.FetchOptions(apiClient, baseRepo, &editable) opts.IO.StopProgressIndicator() if err != nil { return err @@ -250,7 +282,7 @@ func editRun(opts *EditOptions) error { go func(issue *api.Issue) { defer g.Done() - err := prShared.UpdateIssue(httpClient, repo, issue.ID, issue.IsPullRequest(), editable) + err := prShared.UpdateIssue(httpClient, baseRepo, issue.ID, issue.IsPullRequest(), editable) if err != nil { failedIssueChan <- fmt.Sprintf("failed to update %s: %s", issue.URL, err) return diff --git a/pkg/cmd/issue/edit/edit_test.go b/pkg/cmd/issue/edit/edit_test.go index 40fe6491cbe..c9aa4c409f4 100644 --- a/pkg/cmd/issue/edit/edit_test.go +++ b/pkg/cmd/issue/edit/edit_test.go @@ -10,7 +10,9 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/internal/run" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" @@ -26,11 +28,12 @@ func TestNewCmdEdit(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - input string - stdin string - output EditOptions - wantsErr bool + name string + input string + stdin string + output EditOptions + expectedBaseRepo ghrepo.Interface + wantsErr bool }{ { name: "no argument", @@ -42,7 +45,7 @@ func TestNewCmdEdit(t *testing.T) { name: "issue number argument", input: "23", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Interactive: true, }, wantsErr: false, @@ -51,7 +54,7 @@ func TestNewCmdEdit(t *testing.T) { name: "title flag", input: "23 --title test", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Title: prShared.EditableString{ Value: "test", @@ -65,7 +68,7 @@ func TestNewCmdEdit(t *testing.T) { name: "body flag", input: "23 --body test", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Body: prShared.EditableString{ Value: "test", @@ -80,7 +83,7 @@ func TestNewCmdEdit(t *testing.T) { input: "23 --body-file -", stdin: "this is on standard input", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Body: prShared.EditableString{ Value: "this is on standard input", @@ -94,7 +97,7 @@ func TestNewCmdEdit(t *testing.T) { name: "body from file", input: fmt.Sprintf("23 --body-file '%s'", tmpFile), output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Body: prShared.EditableString{ Value: "a body from file", @@ -113,7 +116,7 @@ func TestNewCmdEdit(t *testing.T) { name: "add-assignee flag", input: "23 --add-assignee monalisa,hubot", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Assignees: prShared.EditableSlice{ Add: []string{"monalisa", "hubot"}, @@ -127,7 +130,7 @@ func TestNewCmdEdit(t *testing.T) { name: "remove-assignee flag", input: "23 --remove-assignee monalisa,hubot", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Assignees: prShared.EditableSlice{ Remove: []string{"monalisa", "hubot"}, @@ -141,7 +144,7 @@ func TestNewCmdEdit(t *testing.T) { name: "add-label flag", input: "23 --add-label feature,TODO,bug", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Labels: prShared.EditableSlice{ Add: []string{"feature", "TODO", "bug"}, @@ -155,7 +158,7 @@ func TestNewCmdEdit(t *testing.T) { name: "remove-label flag", input: "23 --remove-label feature,TODO,bug", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Labels: prShared.EditableSlice{ Remove: []string{"feature", "TODO", "bug"}, @@ -169,7 +172,7 @@ func TestNewCmdEdit(t *testing.T) { name: "add-project flag", input: "23 --add-project Cleanup,Roadmap", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Projects: prShared.EditableProjects{ EditableSlice: prShared.EditableSlice{ @@ -185,7 +188,7 @@ func TestNewCmdEdit(t *testing.T) { name: "remove-project flag", input: "23 --remove-project Cleanup,Roadmap", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Projects: prShared.EditableProjects{ EditableSlice: prShared.EditableSlice{ @@ -201,7 +204,7 @@ func TestNewCmdEdit(t *testing.T) { name: "milestone flag", input: "23 --milestone GA", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Milestone: prShared.EditableString{ Value: "GA", @@ -215,7 +218,7 @@ func TestNewCmdEdit(t *testing.T) { name: "remove-milestone flag", input: "23 --remove-milestone", output: EditOptions{ - SelectorArgs: []string{"23"}, + IssueNumbers: []int{23}, Editable: prShared.Editable{ Milestone: prShared.EditableString{ Value: "", @@ -234,7 +237,7 @@ func TestNewCmdEdit(t *testing.T) { name: "add label to multiple issues", input: "23 34 --add-label bug", output: EditOptions{ - SelectorArgs: []string{"23", "34"}, + IssueNumbers: []int{23, 34}, Editable: prShared.Editable{ Labels: prShared.EditableSlice{ Add: []string{"bug"}, @@ -244,6 +247,31 @@ func TestNewCmdEdit(t *testing.T) { }, wantsErr: false, }, + { + name: "argument is hash prefixed number", + // Escaping is required here to avoid what I think is shellex treating it as a comment. + input: "\\#23", + output: EditOptions{ + IssueNumbers: []int{23}, + Interactive: true, + }, + wantsErr: false, + }, + { + name: "argument is a URL", + input: "https://github.com/cli/cli/issues/23", + output: EditOptions{ + IssueNumbers: []int{23}, + Interactive: true, + }, + expectedBaseRepo: ghrepo.New("cli", "cli"), + wantsErr: false, + }, + { + name: "URL arguments parse as different repos", + input: "https://github.com/cli/cli/issues/23 https://github.com/cli/go-gh/issues/23", + wantsErr: true, + }, { name: "interactive multiple issues", input: "23 34", @@ -282,14 +310,23 @@ func TestNewCmdEdit(t *testing.T) { _, err = cmd.ExecuteC() if tt.wantsErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) - assert.Equal(t, tt.output.SelectorArgs, gotOpts.SelectorArgs) + require.NoError(t, err) + assert.Equal(t, tt.output.IssueNumbers, gotOpts.IssueNumbers) assert.Equal(t, tt.output.Interactive, gotOpts.Interactive) assert.Equal(t, tt.output.Editable, gotOpts.Editable) + if tt.expectedBaseRepo != nil { + baseRepo, err := gotOpts.BaseRepo() + require.NoError(t, err) + require.True( + t, + ghrepo.IsSame(tt.expectedBaseRepo, baseRepo), + "expected base repo %+v, got %+v", tt.expectedBaseRepo, baseRepo, + ) + } }) } } @@ -306,7 +343,7 @@ func Test_editRun(t *testing.T) { { name: "non-interactive", input: &EditOptions{ - SelectorArgs: []string{"123"}, + IssueNumbers: []int{123}, Interactive: false, Editable: prShared.Editable{ Title: prShared.EditableString{ @@ -359,7 +396,7 @@ func Test_editRun(t *testing.T) { { name: "non-interactive multiple issues", input: &EditOptions{ - SelectorArgs: []string{"456", "123"}, + IssueNumbers: []int{456, 123}, Interactive: false, Editable: prShared.Editable{ Assignees: prShared.EditableSlice{ @@ -409,7 +446,7 @@ func Test_editRun(t *testing.T) { { name: "non-interactive multiple issues with fetch failures", input: &EditOptions{ - SelectorArgs: []string{"123", "9999"}, + IssueNumbers: []int{123, 9999}, Interactive: false, Editable: prShared.Editable{ Assignees: prShared.EditableSlice{ @@ -454,7 +491,7 @@ func Test_editRun(t *testing.T) { { name: "non-interactive multiple issues with update failures", input: &EditOptions{ - SelectorArgs: []string{"123", "456"}, + IssueNumbers: []int{123, 456}, Interactive: false, Editable: prShared.Editable{ Assignees: prShared.EditableSlice{ @@ -524,7 +561,7 @@ func Test_editRun(t *testing.T) { { name: "interactive", input: &EditOptions{ - SelectorArgs: []string{"123"}, + IssueNumbers: []int{123}, Interactive: true, FieldsToEditSurvey: func(p prShared.EditPrompter, eo *prShared.Editable) error { eo.Title.Edited = true @@ -753,3 +790,88 @@ func mockProjectV2ItemUpdate(t *testing.T, reg *httpmock.Registry) { func(inputs map[string]interface{}) {}), ) } + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + t.Run("when projects v1 is supported, is included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + httpmock.GraphQL(`projectCards`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = editRun(&EditOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + Detector: &fd.EnabledDetectorMock{}, + + IssueNumbers: []int{123}, + Editable: prShared.Editable{ + Projects: prShared.EditableProjects{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"Test Project"}, + Edited: true, + }, + }, + }, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, is not included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Exclude(t, httpmock.GraphQL(`projectCards`)) + + reg.Register( + httpmock.GraphQL(`.*`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we're not really interested in it. + _ = editRun(&EditOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + Detector: &fd.DisabledDetectorMock{}, + + IssueNumbers: []int{123}, + Editable: prShared.Editable{ + Projects: prShared.EditableProjects{ + EditableSlice: prShared.EditableSlice{ + Add: []string{"Test Project"}, + Edited: true, + }, + }, + }, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) +} diff --git a/pkg/cmd/issue/lock/lock.go b/pkg/cmd/issue/lock/lock.go index 4e0dac05815..2f332d21dd0 100644 --- a/pkg/cmd/issue/lock/lock.go +++ b/pkg/cmd/issue/lock/lock.go @@ -19,6 +19,7 @@ import ( "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmd/issue/shared" issueShared "github.com/cli/cli/v2/pkg/cmd/issue/shared" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" @@ -99,20 +100,33 @@ type LockOptions struct { ParentCmd string Reason string - SelectorArg string + IssueNumber int Interactive bool } -func (opts *LockOptions) setCommonOptions(f *cmdutil.Factory, args []string) { +func (opts *LockOptions) setCommonOptions(f *cmdutil.Factory, args []string) error { opts.IO = f.IOStreams opts.HttpClient = f.HttpClient opts.Config = f.Config - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } - opts.SelectorArg = args[0] + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + } + opts.IssueNumber = issueNumber + + return nil } func NewCmdLock(f *cmdutil.Factory, parentName string, runF func(string, *LockOptions) error) *cobra.Command { @@ -129,7 +143,9 @@ func NewCmdLock(f *cmdutil.Factory, parentName string, runF func(string, *LockOp Short: short, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - opts.setCommonOptions(f, args) + if err := opts.setCommonOptions(f, args); err != nil { + return err + } reasonProvided := cmd.Flags().Changed("reason") if reasonProvided { @@ -172,7 +188,9 @@ func NewCmdUnlock(f *cmdutil.Factory, parentName string, runF func(string, *Lock Short: short, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - opts.setCommonOptions(f, args) + if err := opts.setCommonOptions(f, args); err != nil { + return err + } if runF != nil { return runF(Unlock, opts) @@ -214,13 +232,18 @@ func lockRun(state string, opts *LockOptions) error { return err } - issuePr, baseRepo, err := issueShared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.SelectorArg, fields()) - - parent := alias[opts.ParentCmd] + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + issuePr, err := issueShared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, fields()) if err != nil { return err - } else if parent.Typename != issuePr.Typename { + } + + parent := alias[opts.ParentCmd] + if parent.Typename != issuePr.Typename { currentType := alias[parent.Typename] correctType := alias[issuePr.Typename] diff --git a/pkg/cmd/issue/lock/lock_test.go b/pkg/cmd/issue/lock/lock_test.go index f6dcb746dd6..1ca320f3549 100644 --- a/pkg/cmd/issue/lock/lock_test.go +++ b/pkg/cmd/issue/lock/lock_test.go @@ -30,7 +30,7 @@ func Test_NewCmdLock(t *testing.T) { args: "--reason off_topic 451", want: LockOptions{ Reason: "off_topic", - SelectorArg: "451", + IssueNumber: 451, }, }, { @@ -41,9 +41,36 @@ func Test_NewCmdLock(t *testing.T) { name: "no flags", args: "451", want: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, }, }, + { + name: "issue number argument", + args: "451 --repo owner/repo", + want: LockOptions{ + IssueNumber: 451, + }, + }, + { + name: "argument is hash prefixed number", + // Escaping is required here to avoid what I think is shellex treating it as a comment. + args: "\\#451 --repo owner/repo", + want: LockOptions{ + IssueNumber: 451, + }, + }, + { + name: "argument is a URL", + args: "https://github.com/cli/cli/issues/451", + want: LockOptions{ + IssueNumber: 451, + }, + }, + { + name: "argument cannot be parsed to an issue", + args: "unparseable", + wantErr: "invalid issue format: \"unparseable\"", + }, { name: "bad reason", args: "--reason bad 451", @@ -60,7 +87,7 @@ func Test_NewCmdLock(t *testing.T) { args: "451", tty: true, want: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, Interactive: true, }, }, @@ -99,7 +126,7 @@ func Test_NewCmdLock(t *testing.T) { } assert.Equal(t, tt.want.Reason, opts.Reason) - assert.Equal(t, tt.want.SelectorArg, opts.SelectorArg) + assert.Equal(t, tt.want.IssueNumber, opts.IssueNumber) assert.Equal(t, tt.want.Interactive, opts.Interactive) }) } @@ -121,9 +148,36 @@ func Test_NewCmdUnlock(t *testing.T) { name: "no flags", args: "451", want: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, }, }, + { + name: "issue number argument", + args: "451 --repo owner/repo", + want: LockOptions{ + IssueNumber: 451, + }, + }, + { + name: "argument is hash prefixed number", + // Escaping is required here to avoid what I think is shellex treating it as a comment. + args: "\\#451 --repo owner/repo", + want: LockOptions{ + IssueNumber: 451, + }, + }, + { + name: "argument is a URL", + args: "https://github.com/cli/cli/issues/451", + want: LockOptions{ + IssueNumber: 451, + }, + }, + { + name: "argument cannot be parsed to an issue", + args: "unparseable", + wantErr: "invalid issue format: \"unparseable\"", + }, } for _, tt := range cases { @@ -158,7 +212,7 @@ func Test_NewCmdUnlock(t *testing.T) { assert.NoError(t, err) } - assert.Equal(t, tt.want.SelectorArg, opts.SelectorArg) + assert.Equal(t, tt.want.IssueNumber, opts.IssueNumber) }) } } @@ -179,7 +233,7 @@ func Test_runLock(t *testing.T) { name: "lock issue nontty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -203,7 +257,7 @@ func Test_runLock(t *testing.T) { tty: true, opts: LockOptions{ Interactive: true, - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", }, state: Lock, @@ -241,7 +295,7 @@ func Test_runLock(t *testing.T) { tty: true, state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", Reason: "off_topic", }, @@ -268,7 +322,7 @@ func Test_runLock(t *testing.T) { tty: true, state: Unlock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -294,7 +348,7 @@ func Test_runLock(t *testing.T) { name: "unlock issue nontty", state: Unlock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -319,7 +373,7 @@ func Test_runLock(t *testing.T) { name: "lock issue with explicit reason nontty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", Reason: "off_topic", }, @@ -344,7 +398,7 @@ func Test_runLock(t *testing.T) { name: "relock issue tty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", Reason: "off_topic", }, @@ -388,7 +442,7 @@ func Test_runLock(t *testing.T) { name: "relock issue nontty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "issue", Reason: "off_topic", }, @@ -409,7 +463,7 @@ func Test_runLock(t *testing.T) { name: "lock pr nontty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -433,7 +487,7 @@ func Test_runLock(t *testing.T) { tty: true, opts: LockOptions{ Interactive: true, - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", }, state: Lock, @@ -469,7 +523,7 @@ func Test_runLock(t *testing.T) { tty: true, state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", Reason: "off_topic", }, @@ -495,7 +549,7 @@ func Test_runLock(t *testing.T) { name: "lock pr with explicit nontty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", Reason: "off_topic", }, @@ -520,7 +574,7 @@ func Test_runLock(t *testing.T) { name: "unlock pr tty", state: Unlock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -546,7 +600,7 @@ func Test_runLock(t *testing.T) { tty: true, state: Unlock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", }, httpStubs: func(t *testing.T, reg *httpmock.Registry) { @@ -572,7 +626,7 @@ func Test_runLock(t *testing.T) { name: "relock pr tty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", Reason: "off_topic", }, @@ -616,7 +670,7 @@ func Test_runLock(t *testing.T) { name: "relock pr nontty", state: Lock, opts: LockOptions{ - SelectorArg: "451", + IssueNumber: 451, ParentCmd: "pr", Reason: "off_topic", }, diff --git a/pkg/cmd/issue/pin/pin.go b/pkg/cmd/issue/pin/pin.go index dfb11a8811c..290bec50797 100644 --- a/pkg/cmd/issue/pin/pin.go +++ b/pkg/cmd/issue/pin/pin.go @@ -20,7 +20,7 @@ type PinOptions struct { Config func() (gh.Config, error) IO *iostreams.IOStreams BaseRepo func() (ghrepo.Interface, error) - SelectorArg string + IssueNumber int } func NewCmdPin(f *cmdutil.Factory, runF func(*PinOptions) error) *cobra.Command { @@ -51,8 +51,22 @@ func NewCmdPin(f *cmdutil.Factory, runF func(*PinOptions) error) *cobra.Command `), Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - opts.BaseRepo = f.BaseRepo - opts.SelectorArg = args[0] + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } + + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + } + + opts.IssueNumber = issueNumber if runF != nil { return runF(opts) @@ -73,7 +87,12 @@ func pinRun(opts *PinOptions) error { return err } - issue, baseRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.SelectorArg, []string{"id", "number", "title", "isPinned"}) + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number", "title", "isPinned"}) if err != nil { return err } diff --git a/pkg/cmd/issue/pin/pin_test.go b/pkg/cmd/issue/pin/pin_test.go index d4979a30dcb..67b767b32b8 100644 --- a/pkg/cmd/issue/pin/pin_test.go +++ b/pkg/cmd/issue/pin/pin_test.go @@ -1,80 +1,21 @@ package pin import ( - "bytes" "net/http" "testing" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" - "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" - "github.com/google/shlex" "github.com/stretchr/testify/assert" ) func TestNewCmdPin(t *testing.T) { - tests := []struct { - name string - input string - output PinOptions - wantErr bool - errMsg string - }{ - { - name: "no argument", - input: "", - wantErr: true, - errMsg: "accepts 1 arg(s), received 0", - }, - { - name: "issue number", - input: "6", - output: PinOptions{ - SelectorArg: "6", - }, - }, - { - name: "issue url", - input: "https://github.com/cli/cli/6", - output: PinOptions{ - SelectorArg: "https://github.com/cli/cli/6", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ios, _, _, _ := iostreams.Test() - ios.SetStdinTTY(true) - ios.SetStdoutTTY(true) - f := &cmdutil.Factory{ - IOStreams: ios, - } - argv, err := shlex.Split(tt.input) - assert.NoError(t, err) - var gotOpts *PinOptions - cmd := NewCmdPin(f, func(opts *PinOptions) error { - gotOpts = opts - return nil - }) - cmd.SetArgs(argv) - cmd.SetIn(&bytes.Buffer{}) - cmd.SetOut(&bytes.Buffer{}) - cmd.SetErr(&bytes.Buffer{}) - - _, err = cmd.ExecuteC() - if tt.wantErr { - assert.Error(t, err) - assert.Equal(t, tt.errMsg, err.Error()) - return - } - - assert.NoError(t, err) - assert.Equal(t, tt.output.SelectorArg, gotOpts.SelectorArg) - }) - } + // Test shared parsing of issue number / URL. + argparsetest.TestArgParsing(t, NewCmdPin) } func TestPinRun(t *testing.T) { @@ -89,7 +30,7 @@ func TestPinRun(t *testing.T) { { name: "pin issue", tty: true, - opts: &PinOptions{SelectorArg: "20"}, + opts: &PinOptions{IssueNumber: 20}, httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`query IssueByNumber\b`), @@ -113,7 +54,7 @@ func TestPinRun(t *testing.T) { { name: "issue already pinned", tty: true, - opts: &PinOptions{SelectorArg: "20"}, + opts: &PinOptions{IssueNumber: 20}, httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`query IssueByNumber\b`), diff --git a/pkg/cmd/issue/reopen/reopen.go b/pkg/cmd/issue/reopen/reopen.go index 92f18a7d979..f01a8eafcac 100644 --- a/pkg/cmd/issue/reopen/reopen.go +++ b/pkg/cmd/issue/reopen/reopen.go @@ -21,7 +21,7 @@ type ReopenOptions struct { IO *iostreams.IOStreams BaseRepo func() (ghrepo.Interface, error) - SelectorArg string + IssueNumber int Comment string } @@ -37,13 +37,23 @@ func NewCmdReopen(f *cmdutil.Factory, runF func(*ReopenOptions) error) *cobra.Co Short: "Reopen issue", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } - if len(args) > 0 { - opts.SelectorArg = args[0] + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo } + opts.IssueNumber = issueNumber + if runF != nil { return runF(opts) } @@ -64,7 +74,12 @@ func reopenRun(opts *ReopenOptions) error { return err } - issue, baseRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.SelectorArg, []string{"id", "number", "title", "state"}) + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number", "title", "state"}) if err != nil { return err } diff --git a/pkg/cmd/issue/reopen/reopen_test.go b/pkg/cmd/issue/reopen/reopen_test.go index 4b8b33ee1a6..f7c8cb95a32 100644 --- a/pkg/cmd/issue/reopen/reopen_test.go +++ b/pkg/cmd/issue/reopen/reopen_test.go @@ -10,6 +10,7 @@ import ( "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" @@ -18,6 +19,11 @@ import ( "github.com/stretchr/testify/assert" ) +func TestNewCmdReopen(t *testing.T) { + // Test shared parsing of issue number / URL. + argparsetest.TestArgParsing(t, NewCmdReopen) +} + func runCommand(rt http.RoundTripper, isTTY bool, cli string) (*test.CmdOut, error) { ios, _, stdout, stderr := iostreams.Test() ios.SetStdoutTTY(isTTY) diff --git a/pkg/cmd/issue/shared/lookup.go b/pkg/cmd/issue/shared/lookup.go index be79f9a73e6..5c477363bcc 100644 --- a/pkg/cmd/issue/shared/lookup.go +++ b/pkg/cmd/issue/shared/lookup.go @@ -13,136 +13,129 @@ import ( "github.com/cli/cli/v2/api" fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" + o "github.com/cli/cli/v2/pkg/option" "github.com/cli/cli/v2/pkg/set" "golang.org/x/sync/errgroup" ) -// IssueFromArgWithFields loads an issue or pull request with the specified fields. If some of the fields -// could not be fetched by GraphQL, this returns a non-nil issue and a *PartialLoadError. -func IssueFromArgWithFields(httpClient *http.Client, baseRepoFn func() (ghrepo.Interface, error), arg string, fields []string) (*api.Issue, ghrepo.Interface, error) { - issueNumber, baseRepo, err := IssueNumberAndRepoFromArg(arg) - if err != nil { - return nil, nil, err - } - - if baseRepo == nil { - var err error - if baseRepo, err = baseRepoFn(); err != nil { - return nil, nil, err - } - } - - issue, err := findIssueOrPR(httpClient, baseRepo, issueNumber, fields) - return issue, baseRepo, err -} +var issueURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/(?:issues|pull)/(\d+)`) -// IssuesFromArgsWithFields loads 1 or more issues or pull requests with the specified fields. If some of the fields -// could not be fetched by GraphQL, this returns non-nil issues and a *PartialLoadError. -func IssuesFromArgsWithFields(httpClient *http.Client, baseRepoFn func() (ghrepo.Interface, error), args []string, fields []string) ([]*api.Issue, ghrepo.Interface, error) { - var issuesRepo ghrepo.Interface - issueNumbers := make([]int, 0, len(args)) +func ParseIssuesFromArgs(args []string) ([]int, o.Option[ghrepo.Interface], error) { + var repo o.Option[ghrepo.Interface] + issueNumbers := make([]int, len(args)) - for _, arg := range args { - issueNumber, baseRepo, err := IssueNumberAndRepoFromArg(arg) + for i, arg := range args { + // For each argument, parse the issue number and an optional repo + issueNumber, issueRepo, err := ParseIssueFromArg(arg) if err != nil { - return nil, nil, err + return nil, o.None[ghrepo.Interface](), err } - issueNumbers = append(issueNumbers, issueNumber) - if baseRepo == nil { - var err error - if baseRepo, err = baseRepoFn(); err != nil { - return nil, nil, err - } + // if this is our first issue repo found, then we need to set it + if repo.IsNone() { + repo = issueRepo } - if issuesRepo == nil { - issuesRepo = baseRepo - continue + // if there is an issue repo returned, then we need to check if it is the same as the previous one + if issueRepo.IsSome() && repo.IsSome() { + // Unwraps are safe because we've checked for presence above + if !ghrepo.IsSame(repo.Unwrap(), issueRepo.Unwrap()) { + return nil, o.None[ghrepo.Interface](), fmt.Errorf( + "multiple issues must be in same repo: found %q, expected %q", + ghrepo.FullName(issueRepo.Unwrap()), + ghrepo.FullName(repo.Unwrap()), + ) + } } - if !ghrepo.IsSame(issuesRepo, baseRepo) { - return nil, nil, fmt.Errorf( - "multiple issues must be in same repo: found %q, expected %q", - ghrepo.FullName(baseRepo), - ghrepo.FullName(issuesRepo), - ) - } + // add the issue number to the list + issueNumbers[i] = issueNumber } - issuesChan := make(chan *api.Issue, len(args)) - g := errgroup.Group{} - for _, num := range issueNumbers { - issueNumber := num - g.Go(func() error { - issue, err := findIssueOrPR(httpClient, issuesRepo, issueNumber, fields) - if err != nil { - return err - } + return issueNumbers, repo, nil +} - issuesChan <- issue - return nil - }) +func ParseIssueFromArg(arg string) (int, o.Option[ghrepo.Interface], error) { + issueLocator := tryParseIssueFromURL(arg) + if issueLocator, present := issueLocator.Value(); present { + return issueLocator.issueNumber, o.Some(issueLocator.repo), nil } - err := g.Wait() - close(issuesChan) - + issueNumber, err := strconv.Atoi(strings.TrimPrefix(arg, "#")) if err != nil { - return nil, nil, err + return 0, o.None[ghrepo.Interface](), fmt.Errorf("invalid issue format: %q", arg) } - issues := make([]*api.Issue, 0, len(args)) - for issue := range issuesChan { - issues = append(issues, issue) - } - - return issues, issuesRepo, nil + return issueNumber, o.None[ghrepo.Interface](), nil } -var issueURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/(?:issues|pull)/(\d+)`) +type issueLocator struct { + issueNumber int + repo ghrepo.Interface +} -func issueMetadataFromURL(s string) (int, ghrepo.Interface) { - u, err := url.Parse(s) +// tryParseIssueFromURL tries to parse an issue number and repo from a URL. +func tryParseIssueFromURL(maybeURL string) o.Option[issueLocator] { + u, err := url.Parse(maybeURL) if err != nil { - return 0, nil + return o.None[issueLocator]() } if u.Scheme != "https" && u.Scheme != "http" { - return 0, nil + return o.None[issueLocator]() } m := issueURLRE.FindStringSubmatch(u.Path) if m == nil { - return 0, nil + return o.None[issueLocator]() } repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) issueNumber, _ := strconv.Atoi(m[3]) - return issueNumber, repo + return o.Some(issueLocator{ + issueNumber: issueNumber, + repo: repo, + }) } -// Returns the issue number and repo if the issue URL is provided. -// If only the issue number is provided, returns the number and nil repo. -func IssueNumberAndRepoFromArg(arg string) (int, ghrepo.Interface, error) { - issueNumber, baseRepo := issueMetadataFromURL(arg) +type PartialLoadError struct { + error +} - if issueNumber == 0 { - var err error - issueNumber, err = strconv.Atoi(strings.TrimPrefix(arg, "#")) - if err != nil { - return 0, nil, fmt.Errorf("invalid issue format: %q", arg) - } +// FindIssuesOrPRs loads 1 or more issues or pull requests with the specified fields. If some of the fields +// could not be fetched by GraphQL, this returns non-nil issues and a *PartialLoadError. +func FindIssuesOrPRs(httpClient *http.Client, repo ghrepo.Interface, issueNumbers []int, fields []string) ([]*api.Issue, error) { + issuesChan := make(chan *api.Issue, len(issueNumbers)) + g := errgroup.Group{} + for _, num := range issueNumbers { + issueNumber := num + g.Go(func() error { + issue, err := FindIssueOrPR(httpClient, repo, issueNumber, fields) + if err != nil { + return err + } + + issuesChan <- issue + return nil + }) } - return issueNumber, baseRepo, nil -} + err := g.Wait() + close(issuesChan) -type PartialLoadError struct { - error + if err != nil { + return nil, err + } + + issues := make([]*api.Issue, 0, len(issueNumbers)) + for issue := range issuesChan { + issues = append(issues, issue) + } + + return issues, nil } -func findIssueOrPR(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.Issue, error) { +func FindIssueOrPR(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.Issue, error) { fieldSet := set.NewStringSet() fieldSet.AddValues(fields) if fieldSet.Contains("stateReason") { diff --git a/pkg/cmd/issue/shared/lookup_test.go b/pkg/cmd/issue/shared/lookup_test.go index 44f496de44a..f921ca49b71 100644 --- a/pkg/cmd/issue/shared/lookup_test.go +++ b/pkg/cmd/issue/shared/lookup_test.go @@ -2,240 +2,94 @@ package shared import ( "net/http" - "strings" "testing" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" + o "github.com/cli/cli/v2/pkg/option" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestIssueFromArgWithFields(t *testing.T) { - type args struct { - baseRepoFn func() (ghrepo.Interface, error) - selector string - } +func TestParseIssuesFromArgs(t *testing.T) { tests := []struct { - name string - args args - httpStub func(*httpmock.Registry) - wantIssue int - wantRepo string - wantProjects string - wantErr bool + behavior string + args []string + expectedIssueNumbers []int + expectedRepo o.Option[ghrepo.Interface] + expectedErr bool }{ { - name: "number argument", - args: args{ - selector: "13", - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") - }, - }, - httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(`{"data":{"repository":{ - "hasIssuesEnabled": true, - "issue":{"number":13} - }}}`)) - }, - wantIssue: 13, - wantRepo: "https://github.com/OWNER/REPO", + behavior: "when given issue numbers, returns them with no repo", + args: []string{"1", "2"}, + expectedIssueNumbers: []int{1, 2}, + expectedRepo: o.None[ghrepo.Interface](), }, { - name: "number with hash argument", - args: args{ - selector: "#13", - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.FromFullName("OWNER/REPO") - }, - }, - httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(`{"data":{"repository":{ - "hasIssuesEnabled": true, - "issue":{"number":13} - }}}`)) - }, - wantIssue: 13, - wantRepo: "https://github.com/OWNER/REPO", + behavior: "when given # prefixed issue numbers, returns them with no repo", + args: []string{"#1", "#2"}, + expectedIssueNumbers: []int{1, 2}, + expectedRepo: o.None[ghrepo.Interface](), }, { - name: "URL argument", - args: args{ - selector: "https://example.org/OWNER/REPO/issues/13#comment-123", - baseRepoFn: nil, - }, - httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(`{"data":{"repository":{ - "hasIssuesEnabled": true, - "issue":{"number":13} - }}}`)) + behavior: "when given URLs, returns them with the repo", + args: []string{ + "https://github.com/OWNER/REPO/issues/1", + "https://github.com/OWNER/REPO/issues/2", }, - wantIssue: 13, - wantRepo: "https://example.org/OWNER/REPO", + expectedIssueNumbers: []int{1, 2}, + expectedRepo: o.Some(ghrepo.New("OWNER", "REPO")), }, { - name: "PR URL argument", - args: args{ - selector: "https://example.org/OWNER/REPO/pull/13#comment-123", - baseRepoFn: nil, - }, - httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(`{"data":{"repository":{ - "hasIssuesEnabled": true, - "issue":{"number":13} - }}}`)) + behavior: "when given URLs in different repos, errors", + args: []string{ + "https://github.com/OWNER/REPO/issues/1", + "https://github.com/OWNER/OTHERREPO/issues/2", }, - wantIssue: 13, - wantRepo: "https://example.org/OWNER/REPO", + expectedErr: true, }, { - name: "project cards permission issue", - args: args{ - selector: "https://example.org/OWNER/REPO/issues/13", - baseRepoFn: nil, - }, - httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(` - { - "data": { - "repository": { - "hasIssuesEnabled": true, - "issue": { - "number": 13, - "projectCards": { - "nodes": [ - null, - { - "project": {"name": "myproject"}, - "column": {"name": "To Do"} - }, - null, - { - "project": {"name": "other project"}, - "column": null - } - ] - } - } - } - }, - "errors": [ - { - "type": "FORBIDDEN", - "message": "Resource not accessible by integration", - "path": ["repository", "issue", "projectCards", "nodes", 0] - }, - { - "type": "FORBIDDEN", - "message": "Resource not accessible by integration", - "path": ["repository", "issue", "projectCards", "nodes", 2] - } - ] - }`)) - }, - wantErr: true, - wantIssue: 13, - wantProjects: "myproject, other project", - wantRepo: "https://example.org/OWNER/REPO", + behavior: "when given an unparseable argument, errors", + args: []string{"://"}, + expectedErr: true, }, { - name: "projects permission issue", - args: args{ - selector: "https://example.org/OWNER/REPO/issues/13", - baseRepoFn: nil, - }, - httpStub: func(r *httpmock.Registry) { - r.Register( - httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(` - { - "data": { - "repository": { - "hasIssuesEnabled": true, - "issue": { - "number": 13, - "projectCards": { - "nodes": null, - "totalCount": 0 - } - } - } - }, - "errors": [ - { - "type": "FORBIDDEN", - "message": "Resource not accessible by integration", - "path": ["repository", "issue", "projectCards", "nodes"] - } - ] - }`)) - }, - wantErr: true, - wantIssue: 13, - wantProjects: "", - wantRepo: "https://example.org/OWNER/REPO", + behavior: "when given a URL that isn't an issue or PR url, errors", + args: []string{"https://github.com"}, + expectedErr: true, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reg := &httpmock.Registry{} - if tt.httpStub != nil { - tt.httpStub(reg) - } - httpClient := &http.Client{Transport: reg} - issue, repo, err := IssueFromArgWithFields(httpClient, tt.args.baseRepoFn, tt.args.selector, []string{"number"}) - if (err != nil) != tt.wantErr { - t.Errorf("IssueFromArgWithFields() error = %v, wantErr %v", err, tt.wantErr) - if issue == nil { - return - } - } - if issue.Number != tt.wantIssue { - t.Errorf("want issue #%d, got #%d", tt.wantIssue, issue.Number) - } - if gotProjects := strings.Join(issue.ProjectCards.ProjectNames(), ", "); gotProjects != tt.wantProjects { - t.Errorf("want projects %q, got %q", tt.wantProjects, gotProjects) - } - repoURL := ghrepo.GenerateRepoURL(repo, "") - if repoURL != tt.wantRepo { - t.Errorf("want repo %s, got %s", tt.wantRepo, repoURL) + + for _, tc := range tests { + t.Run(tc.behavior, func(t *testing.T) { + issueNumbers, repo, err := ParseIssuesFromArgs(tc.args) + + if tc.expectedErr { + require.Error(t, err) + return } + + require.NoError(t, err) + assert.Equal(t, tc.expectedIssueNumbers, issueNumbers) + assert.Equal(t, tc.expectedRepo, repo) }) } + } -func TestIssuesFromArgsWithFields(t *testing.T) { - type args struct { - baseRepoFn func() (ghrepo.Interface, error) - selectors []string - } +func TestFindIssuesOrPRs(t *testing.T) { tests := []struct { - name string - args args - httpStub func(*httpmock.Registry) - wantIssues []int - wantRepo string - wantErr bool - wantErrMsg string + name string + issueNumbers []int + baseRepo ghrepo.Interface + httpStub func(*httpmock.Registry) + wantIssueNumbers []int + wantErr bool }{ { - name: "multiple repos", - args: args{ - selectors: []string{"1", "https://github.com/OWNER/OTHERREPO/issues/2"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.New("OWNER", "REPO"), nil - }, - }, + name: "multiple issues", + issueNumbers: []int{1, 2}, + baseRepo: ghrepo.New("OWNER", "REPO"), httpStub: func(r *httpmock.Registry) { r.Register( httpmock.GraphQL(`query IssueByNumber\b`), @@ -248,19 +102,14 @@ func TestIssuesFromArgsWithFields(t *testing.T) { httpmock.StringResponse(`{"data":{"repository":{ "hasIssuesEnabled": true, "issue":{"number":2} - }}}`)) + }}}`)) }, - wantErr: true, - wantErrMsg: "multiple issues must be in same repo", + wantIssueNumbers: []int{1, 2}, }, { - name: "multiple issues", - args: args{ - selectors: []string{"1", "2"}, - baseRepoFn: func() (ghrepo.Interface, error) { - return ghrepo.New("OWNER", "REPO"), nil - }, - }, + name: "any find error results in total error", + issueNumbers: []int{1, 2}, + baseRepo: ghrepo.New("OWNER", "REPO"), httpStub: func(r *httpmock.Registry) { r.Register( httpmock.GraphQL(`query IssueByNumber\b`), @@ -270,48 +119,33 @@ func TestIssuesFromArgsWithFields(t *testing.T) { }}}`)) r.Register( httpmock.GraphQL(`query IssueByNumber\b`), - httpmock.StringResponse(`{"data":{"repository":{ - "hasIssuesEnabled": true, - "issue":{"number":2} - }}}`)) + httpmock.StatusStringResponse(500, "internal server error")) }, - wantIssues: []int{1, 2}, - wantRepo: "https://github.com/OWNER/REPO", + wantErr: true, }, } for _, tt := range tests { - if !tt.wantErr && len(tt.args.selectors) != len(tt.wantIssues) { - t.Fatal("number of selectors and issues not equal") - } t.Run(tt.name, func(t *testing.T) { reg := &httpmock.Registry{} if tt.httpStub != nil { tt.httpStub(reg) } httpClient := &http.Client{Transport: reg} - issues, repo, err := IssuesFromArgsWithFields(httpClient, tt.args.baseRepoFn, tt.args.selectors, []string{"number"}) + issues, err := FindIssuesOrPRs(httpClient, tt.baseRepo, tt.issueNumbers, []string{"number"}) if (err != nil) != tt.wantErr { - t.Errorf("IssuesFromArgsWithFields() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("FindIssuesOrPRs() error = %v, wantErr %v", err, tt.wantErr) if issues == nil { return } } if tt.wantErr { - assert.Error(t, err) - assert.ErrorContains(t, err, tt.wantErrMsg) + require.Error(t, err) return } - assert.NoError(t, err) + + require.NoError(t, err) for i := range issues { - assert.Contains(t, tt.wantIssues, issues[i].Number) - } - if repo != nil { - repoURL := ghrepo.GenerateRepoURL(repo, "") - if repoURL != tt.wantRepo { - t.Errorf("want repo %s, got %s", tt.wantRepo, repoURL) - } - } else if tt.wantRepo != "" { - t.Errorf("want repo %sw, got nil", tt.wantRepo) + assert.Contains(t, tt.wantIssueNumbers, issues[i].Number) } }) } diff --git a/pkg/cmd/issue/transfer/transfer.go b/pkg/cmd/issue/transfer/transfer.go index 140d02b9194..a6dfb9b2319 100644 --- a/pkg/cmd/issue/transfer/transfer.go +++ b/pkg/cmd/issue/transfer/transfer.go @@ -20,7 +20,7 @@ type TransferOptions struct { IO *iostreams.IOStreams BaseRepo func() (ghrepo.Interface, error) - IssueSelector string + IssueNumber int DestRepoSelector string } @@ -36,8 +36,23 @@ func NewCmdTransfer(f *cmdutil.Factory, runF func(*TransferOptions) error) *cobr Short: "Transfer issue to another repository", Args: cmdutil.ExactArgs(2, "issue and destination repository are required"), RunE: func(cmd *cobra.Command, args []string) error { - opts.BaseRepo = f.BaseRepo - opts.IssueSelector = args[0] + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } + + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + } + + opts.IssueNumber = issueNumber + opts.DestRepoSelector = args[1] if runF != nil { @@ -57,7 +72,12 @@ func transferRun(opts *TransferOptions) error { return err } - issue, baseRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.IssueSelector, []string{"id", "number"}) + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number"}) if err != nil { return err } diff --git a/pkg/cmd/issue/transfer/transfer_test.go b/pkg/cmd/issue/transfer/transfer_test.go index eed9c5d85c4..2b12db9442f 100644 --- a/pkg/cmd/issue/transfer/transfer_test.go +++ b/pkg/cmd/issue/transfer/transfer_test.go @@ -15,6 +15,7 @@ import ( "github.com/cli/cli/v2/test" "github.com/google/shlex" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func runCommand(rt http.RoundTripper, cli string) (*test.CmdOut, error) { @@ -57,18 +58,49 @@ func runCommand(rt http.RoundTripper, cli string) (*test.CmdOut, error) { func TestNewCmdTransfer(t *testing.T) { tests := []struct { - name string - cli string - wants TransferOptions - wantErr string + name string + cli string + wants TransferOptions + wantBaseRepo ghrepo.Interface + wantErr bool }{ { - name: "issue name", - cli: "3252 OWNER/REPO", + name: "no argument", + cli: "", + wantErr: true, + }, + { + name: "issue number argument", + cli: "--repo cli/repo 23 OWNER/REPO", + wants: TransferOptions{ + IssueNumber: 23, + DestRepoSelector: "OWNER/REPO", + }, + wantBaseRepo: ghrepo.New("cli", "repo"), + }, + { + name: "argument is hash prefixed number", + // Escaping is required here to avoid what I think is shellex treating it as a comment. + cli: "--repo cli/repo \\#23 OWNER/REPO", + wants: TransferOptions{ + IssueNumber: 23, + DestRepoSelector: "OWNER/REPO", + }, + wantBaseRepo: ghrepo.New("cli", "repo"), + }, + { + name: "argument is a URL", + cli: "https://github.com/cli/cli/issues/23 OWNER/REPO", wants: TransferOptions{ - IssueSelector: "3252", + IssueNumber: 23, DestRepoSelector: "OWNER/REPO", }, + wantBaseRepo: ghrepo.New("cli", "cli"), + }, + { + name: "argument cannot be parsed to an issue", + cli: "unparseable OWNER/REPO", + wantErr: true, }, } @@ -84,15 +116,29 @@ func TestNewCmdTransfer(t *testing.T) { gotOpts = opts return nil }) + cmdutil.EnableRepoOverride(cmd, f) + cmd.SetArgs(argv) cmd.SetIn(&bytes.Buffer{}) cmd.SetOut(&bytes.Buffer{}) cmd.SetErr(&bytes.Buffer{}) _, cErr := cmd.ExecuteC() - assert.NoError(t, cErr) - assert.Equal(t, tt.wants.IssueSelector, gotOpts.IssueSelector) + if tt.wantErr { + require.Error(t, cErr) + return + } + + require.NoError(t, cErr) + assert.Equal(t, tt.wants.IssueNumber, gotOpts.IssueNumber) assert.Equal(t, tt.wants.DestRepoSelector, gotOpts.DestRepoSelector) + actualBaseRepo, err := gotOpts.BaseRepo() + require.NoError(t, err) + assert.True( + t, + ghrepo.IsSame(tt.wantBaseRepo, actualBaseRepo), + "expected base repo %+v, got %+v", tt.wantBaseRepo, actualBaseRepo, + ) }) } } diff --git a/pkg/cmd/issue/unpin/unpin.go b/pkg/cmd/issue/unpin/unpin.go index 3ac28d47cf5..ca22aa82eee 100644 --- a/pkg/cmd/issue/unpin/unpin.go +++ b/pkg/cmd/issue/unpin/unpin.go @@ -16,11 +16,12 @@ import ( ) type UnpinOptions struct { - HttpClient func() (*http.Client, error) - Config func() (gh.Config, error) - IO *iostreams.IOStreams - BaseRepo func() (ghrepo.Interface, error) - SelectorArg string + HttpClient func() (*http.Client, error) + Config func() (gh.Config, error) + IO *iostreams.IOStreams + BaseRepo func() (ghrepo.Interface, error) + + IssueNumber int } func NewCmdUnpin(f *cmdutil.Factory, runF func(*UnpinOptions) error) *cobra.Command { @@ -51,8 +52,22 @@ func NewCmdUnpin(f *cmdutil.Factory, runF func(*UnpinOptions) error) *cobra.Comm `), Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - opts.BaseRepo = f.BaseRepo - opts.SelectorArg = args[0] + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } + + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo + } + + opts.IssueNumber = issueNumber if runF != nil { return runF(opts) @@ -73,7 +88,12 @@ func unpinRun(opts *UnpinOptions) error { return err } - issue, baseRepo, err := shared.IssueFromArgWithFields(httpClient, opts.BaseRepo, opts.SelectorArg, []string{"id", "number", "title", "isPinned"}) + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + + issue, err := shared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, []string{"id", "number", "title", "isPinned"}) if err != nil { return err } diff --git a/pkg/cmd/issue/unpin/unpin_test.go b/pkg/cmd/issue/unpin/unpin_test.go index 70a018d946f..3cdf29a748a 100644 --- a/pkg/cmd/issue/unpin/unpin_test.go +++ b/pkg/cmd/issue/unpin/unpin_test.go @@ -1,80 +1,21 @@ package unpin import ( - "bytes" "net/http" "testing" "github.com/cli/cli/v2/internal/config" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" - "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" - "github.com/google/shlex" "github.com/stretchr/testify/assert" ) -func TestNewCmdPin(t *testing.T) { - tests := []struct { - name string - input string - output UnpinOptions - wantErr bool - errMsg string - }{ - { - name: "no argument", - input: "", - wantErr: true, - errMsg: "accepts 1 arg(s), received 0", - }, - { - name: "issue number", - input: "6", - output: UnpinOptions{ - SelectorArg: "6", - }, - }, - { - name: "issue url", - input: "https://github.com/cli/cli/6", - output: UnpinOptions{ - SelectorArg: "https://github.com/cli/cli/6", - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ios, _, _, _ := iostreams.Test() - ios.SetStdinTTY(true) - ios.SetStdoutTTY(true) - f := &cmdutil.Factory{ - IOStreams: ios, - } - argv, err := shlex.Split(tt.input) - assert.NoError(t, err) - var gotOpts *UnpinOptions - cmd := NewCmdUnpin(f, func(opts *UnpinOptions) error { - gotOpts = opts - return nil - }) - cmd.SetArgs(argv) - cmd.SetIn(&bytes.Buffer{}) - cmd.SetOut(&bytes.Buffer{}) - cmd.SetErr(&bytes.Buffer{}) - - _, err = cmd.ExecuteC() - if tt.wantErr { - assert.Error(t, err) - assert.Equal(t, tt.errMsg, err.Error()) - return - } - - assert.NoError(t, err) - assert.Equal(t, tt.output.SelectorArg, gotOpts.SelectorArg) - }) - } +func TestNewCmdUnpin(t *testing.T) { + // Test shared parsing of issue number / URL. + argparsetest.TestArgParsing(t, NewCmdUnpin) } func TestUnpinRun(t *testing.T) { @@ -89,7 +30,7 @@ func TestUnpinRun(t *testing.T) { { name: "unpin issue", tty: true, - opts: &UnpinOptions{SelectorArg: "20"}, + opts: &UnpinOptions{IssueNumber: 20}, httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`query IssueByNumber\b`), @@ -113,7 +54,7 @@ func TestUnpinRun(t *testing.T) { { name: "issue not pinned", tty: true, - opts: &UnpinOptions{SelectorArg: "20"}, + opts: &UnpinOptions{IssueNumber: 20}, httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.GraphQL(`query IssueByNumber\b`), diff --git a/pkg/cmd/issue/view/view.go b/pkg/cmd/issue/view/view.go index 8e3aa604038..a9e25513bc9 100644 --- a/pkg/cmd/issue/view/view.go +++ b/pkg/cmd/issue/view/view.go @@ -12,8 +12,11 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/browser" + fd "github.com/cli/cli/v2/internal/featuredetection" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/text" + "github.com/cli/cli/v2/pkg/cmd/issue/shared" issueShared "github.com/cli/cli/v2/pkg/cmd/issue/shared" prShared "github.com/cli/cli/v2/pkg/cmd/pr/shared" "github.com/cli/cli/v2/pkg/cmdutil" @@ -28,8 +31,9 @@ type ViewOptions struct { IO *iostreams.IOStreams BaseRepo func() (ghrepo.Interface, error) Browser browser.Browser + Detector fd.Detector - SelectorArg string + IssueNumber int WebMode bool Comments bool Exporter cmdutil.Exporter @@ -55,13 +59,23 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman `, "`"), Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - // support `-R, --repo` override - opts.BaseRepo = f.BaseRepo + issueNumber, baseRepo, err := shared.ParseIssueFromArg(args[0]) + if err != nil { + return err + } - if len(args) > 0 { - opts.SelectorArg = args[0] + // If the args provided the base repo then use that directly. + if baseRepo, present := baseRepo.Value(); present { + opts.BaseRepo = func() (ghrepo.Interface, error) { + return baseRepo, nil + } + } else { + // support `-R, --repo` override + opts.BaseRepo = f.BaseRepo } + opts.IssueNumber = issueNumber + if runF != nil { return runF(opts) } @@ -78,7 +92,7 @@ func NewCmdView(f *cmdutil.Factory, runF func(*ViewOptions) error) *cobra.Comman var defaultFields = []string{ "number", "url", "state", "createdAt", "title", "body", "author", "milestone", - "assignees", "labels", "projectCards", "reactionGroups", "lastComment", "stateReason", + "assignees", "labels", "reactionGroups", "lastComment", "stateReason", } func viewRun(opts *ViewOptions) error { @@ -87,6 +101,11 @@ func viewRun(opts *ViewOptions) error { return err } + baseRepo, err := opts.BaseRepo() + if err != nil { + return err + } + lookupFields := set.NewStringSet() if opts.Exporter != nil { lookupFields.AddValues(opts.Exporter.Fields()) @@ -98,12 +117,35 @@ func viewRun(opts *ViewOptions) error { lookupFields.Add("comments") lookupFields.Remove("lastComment") } + + // TODO projectsV1Deprecation + // Remove this section as we should no longer add projectCards + if opts.Detector == nil { + cachedClient := api.NewCachedHTTPClient(httpClient, time.Hour*24) + opts.Detector = fd.NewDetector(cachedClient, baseRepo.RepoHost()) + } + + projectsV1Support := opts.Detector.ProjectsV1() + if projectsV1Support == gh.ProjectsV1Supported { + lookupFields.Add("projectCards") + } } opts.IO.DetectTerminalTheme() opts.IO.StartProgressIndicator() - issue, baseRepo, err := findIssue(httpClient, opts.BaseRepo, opts.SelectorArg, lookupFields.ToSlice()) + lookupFields.Add("id") + + issue, err := issueShared.FindIssueOrPR(httpClient, baseRepo, opts.IssueNumber, lookupFields.ToSlice()) + if err != nil { + return err + } + + if lookupFields.Contains("comments") { + // FIXME: this re-fetches the comments connection even though the initial set of 100 were + // fetched in the previous request. + err = preloadIssueComments(httpClient, baseRepo, issue) + } opts.IO.StopProgressIndicator() if err != nil { var loadErr *issueShared.PartialLoadError @@ -143,24 +185,6 @@ func viewRun(opts *ViewOptions) error { return printRawIssuePreview(opts.IO.Out, issue) } -func findIssue(client *http.Client, baseRepoFn func() (ghrepo.Interface, error), selector string, fields []string) (*api.Issue, ghrepo.Interface, error) { - fieldSet := set.NewStringSet() - fieldSet.AddValues(fields) - fieldSet.Add("id") - - issue, repo, err := issueShared.IssueFromArgWithFields(client, baseRepoFn, selector, fieldSet.ToSlice()) - if err != nil { - return issue, repo, err - } - - if fieldSet.Contains("comments") { - // FIXME: this re-fetches the comments connection even though the initial set of 100 were - // fetched in the previous request. - err = preloadIssueComments(client, repo, issue) - } - return issue, repo, err -} - func printRawIssuePreview(out io.Writer, issue *api.Issue) error { assignees := issueAssigneeList(*issue) labels := issueLabelList(issue, nil) diff --git a/pkg/cmd/issue/view/view_test.go b/pkg/cmd/issue/view/view_test.go index e1798af9f83..391a288fb21 100644 --- a/pkg/cmd/issue/view/view_test.go +++ b/pkg/cmd/issue/view/view_test.go @@ -10,9 +10,11 @@ import ( "github.com/cli/cli/v2/internal/browser" "github.com/cli/cli/v2/internal/config" + fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/run" + "github.com/cli/cli/v2/pkg/cmd/issue/argparsetest" "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/httpmock" "github.com/cli/cli/v2/pkg/iostreams" @@ -47,6 +49,11 @@ func TestJSONFields(t *testing.T) { }) } +func TestNewCmdView(t *testing.T) { + // Test shared parsing of issue number / URL. + argparsetest.TestArgParsing(t, NewCmdView) +} + func runCommand(rt http.RoundTripper, isTTY bool, cli string) (*test.CmdOut, error) { ios, _, stdout, stderr := iostreams.Test() ios.SetStdoutTTY(isTTY) @@ -116,7 +123,7 @@ func TestIssueView_web(t *testing.T) { return ghrepo.New("OWNER", "REPO"), nil }, WebMode: true, - SelectorArg: "123", + IssueNumber: 123, }) if err != nil { t.Errorf("error running command `issue view`: %v", err) @@ -273,7 +280,7 @@ func TestIssueView_tty_Preview(t *testing.T) { BaseRepo: func() (ghrepo.Interface, error) { return ghrepo.New("OWNER", "REPO"), nil }, - SelectorArg: "123", + IssueNumber: 123, } err := viewRun(&opts) @@ -490,3 +497,66 @@ func TestIssueView_nontty_Comments(t *testing.T) { }) } } + +// TODO projectsV1Deprecation +// Remove this test. +func TestProjectsV1Deprecation(t *testing.T) { + t.Run("when projects v1 is supported, is included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Register( + httpmock.GraphQL(`projectCards`), + // Simulate a GraphQL error to early exit the test. + httpmock.StatusStringResponse(500, ""), + ) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we have no way to really stub it without + // fully stubbing a GQL error structure in the request body. + _ = viewRun(&ViewOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + + Detector: &fd.EnabledDetectorMock{}, + IssueNumber: 123, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) + + t.Run("when projects v1 is not supported, is not included in query", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + reg := &httpmock.Registry{} + reg.Exclude(t, httpmock.GraphQL(`projectCards`)) + + _, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + // Ignore the error because we're not really interested in it. + _ = viewRun(&ViewOptions{ + IO: ios, + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + BaseRepo: func() (ghrepo.Interface, error) { + return ghrepo.New("OWNER", "REPO"), nil + }, + + Detector: &fd.DisabledDetectorMock{}, + IssueNumber: 123, + }) + + // Verify that our request contained projectCards + reg.Verify(t) + }) +} diff --git a/pkg/cmd/pr/create/create.go b/pkg/cmd/pr/create/create.go index 5f8979c11c1..7f960bce446 100644 --- a/pkg/cmd/pr/create/create.go +++ b/pkg/cmd/pr/create/create.go @@ -25,6 +25,7 @@ import ( "github.com/cli/cli/v2/pkg/cmdutil" "github.com/cli/cli/v2/pkg/iostreams" "github.com/cli/cli/v2/pkg/markdown" + o "github.com/cli/cli/v2/pkg/option" "github.com/spf13/cobra" ) @@ -72,18 +73,107 @@ type CreateOptions struct { DryRun bool } +// creationRefs is an interface that provides the necessary information for creating a pull request in the API. +// Upcasting to concrete implementations can provide further context on other operations (forking and pushing). +type creationRefs interface { + // QualifiedHeadRef returns a stringified form of the head ref, varying depending + // on whether the head ref is in the same repository as the base ref. If they are + // the same repository, we return the branch name only. If they are different repositories, + // we return the owner and branch name in the form :. + QualifiedHeadRef() string + // UnqualifiedHeadRef returns a head ref in the form of the branch name only. + UnqualifiedHeadRef() string + //BaseRef returns the base branch name. + BaseRef() string + + // While the only thing really required from an api.Repository is the repository ID, changing that + // would require changing the API function signatures, and the refactor that introduced this refs + // type is already large enough. + BaseRepo() *api.Repository +} + +type baseRefs struct { + baseRepo *api.Repository + baseBranchName string +} + +func (r baseRefs) BaseRef() string { + return r.baseBranchName +} + +func (r baseRefs) BaseRepo() *api.Repository { + return r.baseRepo +} + +// skipPushRefs indicate to handlePush that no pushing is required. +type skipPushRefs struct { + baseRefs + + qualifiedHeadRef shared.QualifiedHeadRef +} + +func (r skipPushRefs) QualifiedHeadRef() string { + return r.qualifiedHeadRef.String() +} + +func (r skipPushRefs) UnqualifiedHeadRef() string { + return r.qualifiedHeadRef.BranchName() +} + +// pushableRefs indicate to handlePush that pushing is required, +// and provide further information (HeadRepo) on where that push +// should go. +type pushableRefs struct { + baseRefs + + headRepo ghrepo.Interface + headBranchName string +} + +func (r pushableRefs) QualifiedHeadRef() string { + if ghrepo.IsSame(r.headRepo, r.baseRepo) { + return r.headBranchName + } + return fmt.Sprintf("%s:%s", r.headRepo.RepoOwner(), r.headBranchName) +} + +func (r pushableRefs) UnqualifiedHeadRef() string { + return r.headBranchName +} + +func (r pushableRefs) HeadRepo() ghrepo.Interface { + return r.headRepo +} + +// forkableRefs indicate to handlePush that forking is required before +// pushing. The expectation is that after forking, this is converted to +// pushableRefs. We could go very OOP and have a Fork method on this +// struct that returns a pushableRefs but then we'd need to embed an API client +// and it just seems nice that it is a simple bag of data. +type forkableRefs struct { + baseRefs + + qualifiedHeadRef shared.QualifiedHeadRef +} + +func (r forkableRefs) QualifiedHeadRef() string { + return r.qualifiedHeadRef.String() +} + +func (r forkableRefs) UnqualifiedHeadRef() string { + return r.qualifiedHeadRef.BranchName() +} + +// CreateContext stores contextual data about the creation process and is for building up enough +// data to create a pull request. type CreateContext struct { - // This struct stores contextual data about the creation process and is for building up enough - // data to create a pull request - RepoContext *ghContext.ResolvedRemotes - BaseRepo *api.Repository - HeadRepo ghrepo.Interface + ResolvedRemotes *ghContext.ResolvedRemotes + PRRefs creationRefs + // BaseTrackingBranch is perhaps a slightly leaky abstraction in the presence + // of PRRefs, but a huge amount of refactoring was done to introduce that struct, + // and this is a small price to pay for the convenience of not having to do a lot + // more design. BaseTrackingBranch string - BaseBranch string - HeadBranch string - HeadBranchLabel string - HeadRemote *ghContext.Remote - IsPushEnabled bool Client *api.Client GitClient *git.Client } @@ -113,6 +203,10 @@ func NewCmdCreate(f *cmdutil.Factory, runF func(*CreateOptions) error) *cobra.Co to push the branch and offer an option to fork the base repository. Use %[1]s--head%[1]s to explicitly skip any forking or pushing behavior. + %[1]s--head%[1]s supports %[1]s:%[1]s syntax to select a head repo owned by %[1]s%[1]s. + Using an organization as the %[1]s%[1]s is currently not supported. + For more information, see + A prompt will also ask for the title and the body of the pull request. Use %[1]s--title%[1]s and %[1]s--body%[1]s to skip this, or use %[1]s--fill%[1]s to autofill these values from git commits. It's important to notice that if the %[1]s--title%[1]s and/or %[1]s--body%[1]s are also provided @@ -310,8 +404,8 @@ func createRun(opts *CreateOptions) error { } existingPR, _, err := opts.Finder.Find(shared.FindOptions{ - Selector: ctx.HeadBranchLabel, - BaseBranch: ctx.BaseBranch, + Selector: ctx.PRRefs.QualifiedHeadRef(), + BaseBranch: ctx.PRRefs.BaseRef(), States: []string{"OPEN"}, Fields: []string{"url"}, }) @@ -321,7 +415,7 @@ func createRun(opts *CreateOptions) error { } if err == nil { return fmt.Errorf("a pull request for branch %q into branch %q already exists:\n%s", - ctx.HeadBranchLabel, ctx.BaseBranch, existingPR.URL) + ctx.PRRefs.QualifiedHeadRef(), ctx.PRRefs.BaseRef(), existingPR.URL) } message := "\nCreating pull request for %s into %s in %s\n\n" @@ -336,9 +430,9 @@ func createRun(opts *CreateOptions) error { if opts.IO.CanPrompt() { fmt.Fprintf(opts.IO.ErrOut, message, - cs.Cyan(ctx.HeadBranchLabel), - cs.Cyan(ctx.BaseBranch), - ghrepo.FullName(ctx.BaseRepo)) + cs.Cyan(ctx.PRRefs.QualifiedHeadRef()), + cs.Cyan(ctx.PRRefs.BaseRef()), + ghrepo.FullName(ctx.PRRefs.BaseRepo())) } if !opts.EditorMode && (opts.FillVerbose || opts.Autofill || opts.FillFirst || (opts.TitleProvided && opts.BodyProvided)) { @@ -346,7 +440,8 @@ func createRun(opts *CreateOptions) error { if err != nil { return err } - return submitPR(*opts, *ctx, *state) + // TODO wm: revisit project support + return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) } if opts.RecoverFile != "" { @@ -361,7 +456,7 @@ func createRun(opts *CreateOptions) error { action = shared.SubmitDraftAction } - tpl := shared.NewTemplateManager(client.HTTP(), ctx.BaseRepo, opts.Prompter, opts.RootDirOverride, opts.RepoOverride == "", true) + tpl := shared.NewTemplateManager(client.HTTP(), ctx.PRRefs.BaseRepo(), opts.Prompter, opts.RootDirOverride, opts.RepoOverride == "", true) if opts.EditorMode { if opts.Template != "" { @@ -429,7 +524,7 @@ func createRun(opts *CreateOptions) error { } allowPreview := !state.HasMetadata() && shared.ValidURL(openURL) && !opts.DryRun - allowMetadata := ctx.BaseRepo.ViewerCanTriage() + allowMetadata := ctx.PRRefs.BaseRepo().ViewerCanTriage() action, err = shared.ConfirmPRSubmission(opts.Prompter, allowPreview, allowMetadata, state.Draft) if err != nil { return fmt.Errorf("unable to confirm: %w", err) @@ -439,10 +534,11 @@ func createRun(opts *CreateOptions) error { fetcher := &shared.MetadataFetcher{ IO: opts.IO, APIClient: client, - Repo: ctx.BaseRepo, + Repo: ctx.PRRefs.BaseRepo(), State: state, } - err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.BaseRepo, fetcher, state) + // TODO wm: revisit project support + err = shared.MetadataSurvey(opts.Prompter, opts.IO, ctx.PRRefs.BaseRepo(), fetcher, state, gh.ProjectsV1Supported) if err != nil { return err } @@ -471,11 +567,13 @@ func createRun(opts *CreateOptions) error { if action == shared.SubmitDraftAction { state.Draft = true - return submitPR(*opts, *ctx, *state) + // TODO wm: revisit project support + return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) } if action == shared.SubmitAction { - return submitPR(*opts, *ctx, *state) + // TODO wm: revisit project support + return submitPR(*opts, *ctx, *state, gh.ProjectsV1Supported) } err = errors.New("expected to cancel, preview, or submit") @@ -485,11 +583,7 @@ func createRun(opts *CreateOptions) error { var regexPattern = regexp.MustCompile(`(?m)^`) func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, useFirstCommit bool, addBody bool) error { - baseRef := ctx.BaseTrackingBranch - headRef := ctx.HeadBranch - gitClient := ctx.GitClient - - commits, err := gitClient.Commits(context.Background(), baseRef, headRef) + commits, err := ctx.GitClient.Commits(context.Background(), ctx.BaseTrackingBranch, ctx.PRRefs.UnqualifiedHeadRef()) if err != nil { return err } @@ -498,7 +592,7 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u state.Title = commits[len(commits)-1].Title state.Body = commits[len(commits)-1].Body } else { - state.Title = humanize(headRef) + state.Title = humanize(ctx.PRRefs.UnqualifiedHeadRef()) var body strings.Builder for i := len(commits) - 1; i >= 0; i-- { fmt.Fprintf(&body, "- **%s**\n", commits[i].Title) @@ -518,103 +612,26 @@ func initDefaultTitleBody(ctx CreateContext, state *shared.IssueMetadataState, u return nil } -// TODO: Replace with the finder's PullRequestRefs struct -// trackingRef represents a ref for a remote tracking branch. -type trackingRef struct { - remoteName string - branchName string -} - -func (r trackingRef) String() string { - return "refs/remotes/" + r.remoteName + "/" + r.branchName -} - -func mustParseTrackingRef(text string) trackingRef { - parts := strings.SplitN(string(text), "/", 4) - // The only place this is called is tryDetermineTrackingRef, where we are reconstructing - // the same tracking ref we passed in. If it doesn't match the expected format, this is a - // programmer error we want to know about, so it's ok to panic. - if len(parts) != 4 { - panic(fmt.Errorf("tracking ref should have four parts: %s", text)) - } - if parts[0] != "refs" || parts[1] != "remotes" { - panic(fmt.Errorf("tracking ref should start with refs/remotes/: %s", text)) - } - - return trackingRef{ - remoteName: parts[2], - branchName: parts[3], - } -} - -// tryDetermineTrackingRef is intended to try and find a remote branch on the same commit as the currently checked out -// HEAD, i.e. the local branch. If there are multiple branches that might match, the first remote is chosen, which in -// practice is determined by the sorting algorithm applied much earlier in the process, roughly "upstream", "github", "origin", -// and then everything else unstably sorted. -func tryDetermineTrackingRef(gitClient *git.Client, remotes ghContext.Remotes, localBranchName string, headBranchConfig git.BranchConfig) (trackingRef, bool) { - // To try and determine the tracking ref for a local branch, we first construct a collection of refs - // that might be tracking, given the current branch's config, and the list of known remotes. - refsForLookup := []string{"HEAD"} - if headBranchConfig.RemoteName != "" && headBranchConfig.MergeRef != "" { - tr := trackingRef{ - remoteName: headBranchConfig.RemoteName, - branchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), - } - refsForLookup = append(refsForLookup, tr.String()) - } - - for _, remote := range remotes { - tr := trackingRef{ - remoteName: remote.Name, - branchName: localBranchName, - } - refsForLookup = append(refsForLookup, tr.String()) - } - - // Then we ask git for details about these refs, for example, refs/remotes/origin/trunk might return a hash - // for the remote tracking branch, trunk, for the remote, origin. If there is no ref, the git client returns - // no ref information. - // - // We also first check for the HEAD ref, so that we have the hash of the currently checked out commit. - resolvedRefs, _ := gitClient.ShowRefs(context.Background(), refsForLookup) - - // If there is more than one resolved ref, that means that at least one ref was found in addition to the HEAD. - if len(resolvedRefs) > 1 { - headRef := resolvedRefs[0] - for _, r := range resolvedRefs[1:] { - // If the hash of the remote ref doesn't match the hash of HEAD then the remote branch is not in the same - // state, so it can't be used. - if r.Hash != headRef.Hash { - continue - } - // Otherwise we can parse the returned ref into a tracking ref and return that - return mustParseTrackingRef(r.Name), true - } - } - - return trackingRef{}, false -} - func NewIssueState(ctx CreateContext, opts CreateOptions) (*shared.IssueMetadataState, error) { var milestoneTitles []string if opts.Milestone != "" { milestoneTitles = []string{opts.Milestone} } - meReplacer := shared.NewMeReplacer(ctx.Client, ctx.BaseRepo.RepoHost()) + meReplacer := shared.NewMeReplacer(ctx.Client, ctx.PRRefs.BaseRepo().RepoHost()) assignees, err := meReplacer.ReplaceSlice(opts.Assignees) if err != nil { return nil, err } state := &shared.IssueMetadataState{ - Type: shared.PRMetadata, - Reviewers: opts.Reviewers, - Assignees: assignees, - Labels: opts.Labels, - Projects: opts.Projects, - Milestones: milestoneTitles, - Draft: opts.IsDraft, + Type: shared.PRMetadata, + Reviewers: opts.Reviewers, + Assignees: assignees, + Labels: opts.Labels, + ProjectTitles: opts.Projects, + Milestones: milestoneTitles, + Draft: opts.IsDraft, } if opts.FillVerbose || opts.Autofill || opts.FillFirst || !opts.TitleProvided || !opts.BodyProvided { @@ -638,13 +655,14 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { if err != nil { return nil, err } - repoContext, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) + + resolvedRemotes, err := ghContext.ResolveRemotesToRepos(remotes, client, opts.RepoOverride) if err != nil { return nil, err } var baseRepo *api.Repository - if br, err := repoContext.BaseRepo(opts.IO); err == nil { + if br, err := resolvedRemotes.BaseRepo(opts.IO); err == nil { if r, ok := br.(*api.Repository); ok { baseRepo = r } else { @@ -659,137 +677,284 @@ func NewCreateContext(opts *CreateOptions) (*CreateContext, error) { return nil, err } - isPushEnabled := false - headBranch := opts.HeadBranch - headBranchLabel := opts.HeadBranch - if headBranch == "" { - headBranch, err = opts.Branch() + // This closure provides an easy way to instantiate a CreateContext with everything other than + // the refs. This probably indicates that CreateContext could do with some rework, but the refactor + // to introduce PRRefs is already large enough. + var newCreateContext = func(refs creationRefs) *CreateContext { + baseTrackingBranch := refs.BaseRef() + + // The baseTrackingBranch is used later for a command like: + // `git commit upstream/main feature` in order to create a PR message showing the commits + // between these two refs. I'm not really sure what is expected to happen if we don't have a remote, + // which seems like it would be possible with a command `gh pr create --repo owner/repo-that-is-not-a-remote`. + // In that case, we might just have a mess? In any case, this is what the old code did, so I don't want to change + // it as part of an already large refactor. + baseRemote, _ := resolvedRemotes.RemoteForRepo(baseRepo) + if baseRemote != nil { + baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseTrackingBranch) + } + + return &CreateContext{ + ResolvedRemotes: resolvedRemotes, + Client: client, + GitClient: opts.GitClient, + PRRefs: refs, + BaseTrackingBranch: baseTrackingBranch, + } + } + + // If the user provided a head branch we're going to use that without any interrogation + // of git. The value can take the form of or :. In the former case, the + // PR base and head repos are the same. In the latter case we don't know the head repo + // (though we could look it up in the API) but fortunately we don't need to because the API + // will resolve this for us when we create the pull request. This is possible because + // users can only have a single fork in their namespace, and organizations don't work at all with this ref format. + // + // Note that providing the head branch in this way indicates that we shouldn't push the branch, + // and we indicate that via the returned type as well. + if opts.HeadBranch != "" { + qualifiedHeadRef, err := shared.ParseQualifiedHeadRef(opts.HeadBranch) + if err != nil { + return nil, err + } + + branchConfig, err := opts.GitClient.ReadBranchConfig(context.Background(), qualifiedHeadRef.BranchName()) if err != nil { - return nil, fmt.Errorf("could not determine the current branch: %w", err) + return nil, err } - headBranchLabel = headBranch - isPushEnabled = true - } else if idx := strings.IndexRune(headBranch, ':'); idx >= 0 { - headBranch = headBranch[idx+1:] + + baseBranch := opts.BaseBranch + if baseBranch == "" { + baseBranch = branchConfig.MergeBase + } + if baseBranch == "" { + baseBranch = baseRepo.DefaultBranchRef.Name + } + + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRefs: baseRefs{ + baseRepo: baseRepo, + baseBranchName: baseBranch, + }, + }), nil } - gitClient := opts.GitClient - if ucc, err := gitClient.UncommittedChangeCount(context.Background()); err == nil && ucc > 0 { + if ucc, err := opts.GitClient.UncommittedChangeCount(context.Background()); err == nil && ucc > 0 { fmt.Fprintf(opts.IO.ErrOut, "Warning: %s\n", text.Pluralize(ucc, "uncommitted change")) } - var headRepo ghrepo.Interface - var headRemote *ghContext.Remote + // If the user didn't provide a head branch then we're gettin' real. We're going to interrogate git + // and try to create refs that are pushable. + currentBranch, err := opts.Branch() + if err != nil { + return nil, fmt.Errorf("could not determine the current branch: %w", err) + } - headBranchConfig, err := gitClient.ReadBranchConfig(context.Background(), headBranch) + branchConfig, err := opts.GitClient.ReadBranchConfig(context.Background(), currentBranch) if err != nil { return nil, err } - if isPushEnabled { - // TODO: This doesn't respect the @{push} revision resolution or triagular workflows assembled with - // remote.pushDefault, or branch..pushremote config settings. The finder's ParsePRRefs - // may be able to replace this function entirely. - if trackingRef, found := tryDetermineTrackingRef(gitClient, remotes, headBranch, headBranchConfig); found { - isPushEnabled = false - if r, err := remotes.FindByName(trackingRef.remoteName); err == nil { - headRepo = r - headRemote = r - headBranchLabel = trackingRef.branchName - if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), trackingRef.branchName) + + baseBranch := opts.BaseBranch + if baseBranch == "" { + baseBranch = branchConfig.MergeBase + } + if baseBranch == "" { + baseBranch = baseRepo.DefaultBranchRef.Name + } + + // First we check with the git information we have to see if we can figure out the default + // head repo and remote branch name. + defaultPRHead, err := shared.TryDetermineDefaultPRHead( + // We requested the branch config already, so let's cache that + shared.CachedBranchConfigGitConfigClient{ + CachedBranchConfig: branchConfig, + GitConfigClient: opts.GitClient, + }, + shared.NewRemoteToRepoResolver(opts.Remotes), + currentBranch, + ) + if err != nil { + return nil, err + } + + // The baseRefs are always going to be the same from now on. If I could make this immutable I would! + baseRefs := baseRefs{ + baseRepo: baseRepo, + baseBranchName: baseBranch, + } + + // If we were able to determine a head repo, then let's check that the remote tracking ref matches the SHA of + // HEAD. If it does, then we don't need to push, otherwise we'll need to ask the user to tell us where to push. + if headRepo, present := defaultPRHead.Repo.Value(); present { + // We may not find a remote because the git branch config may have a URL rather than a remote name. + // Ideally, we would return a sentinel error from RemoteForRepo that we could compare to, but the + // refactor that introduced this code was already large enough. + headRemote, _ := resolvedRemotes.RemoteForRepo(headRepo) + if headRemote != nil { + resolvedRefs, _ := opts.GitClient.ShowRefs( + context.Background(), + []string{ + "HEAD", + fmt.Sprintf("refs/remotes/%s/%s", headRemote.Name, defaultPRHead.BranchName), + }, + ) + + // Two refs returned means we can compare HEAD to the remote tracking branch. + // If we had a matching ref, then we can skip pushing. + refsMatch := len(resolvedRefs) == 2 && resolvedRefs[0].Hash == resolvedRefs[1].Hash + if refsMatch { + qualifiedHeadRef := shared.NewQualifiedHeadRefWithoutOwner(defaultPRHead.BranchName) + if headRepo.RepoOwner() != baseRepo.RepoOwner() { + qualifiedHeadRef = shared.NewQualifiedHeadRef(headRepo.RepoOwner(), defaultPRHead.BranchName) } + + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRefs: baseRefs, + }), nil } } } - // otherwise, ask the user for the head repository using info obtained from the API - if headRepo == nil && isPushEnabled && opts.IO.CanPrompt() { - pushableRepos, err := repoContext.HeadRepos() - if err != nil { - return nil, err + // If we didn't determine that the git indicated repo had the correct ref, we'll take a look at the other + // remotes and see whether any of them have the same SHA as HEAD. Now, at this point, you might be asking yourself: + // "Why didn't we collect all the SHAs with a single ShowRefs command above, for use in both cases?" + // ... + // That's because the code below has a bug that I've ported from the old code, in order to preserve the existing + // behaviour, and to limit the scope of an already large refactor. The intention of the original code was to loop + // over all the returned refs. However, as it turns out, our implementation of ShowRefs doesn't do that correctly. + // Since it provides the --verify flag, git will return the SHAs for refs up until it hits a ref that doesn't exist, + // at which point it bails out. + // + // Imagine you have a remotes "upstream" and "origin", and you have pushed your branch "feature" to "origin". Since + // the order of remotes is always guaranteed "upstream", "github", "origin", and then everything else unstably sorted, + // we will never get a SHA for origin, as refs/remotes/upstream/feature doesn't exist. + // + // Furthermore, when you really think about it, this code is a bit eager. What happens if you have the same SHA on + // remotes "origin" and "colleague", this will always offer origin. If it were "colleague-a" and "colleague-b", no + // order would be guaranteed between different invocations of pr create, because the order of remotes after "origin" + // is unstable sorted. + // + // All that said, this has been the behaviour for a long, long time, and I do not want to make other behavioural changes + // in what is mostly a refactor. + refsToLookup := []string{"HEAD"} + for _, remote := range remotes { + refsToLookup = append(refsToLookup, fmt.Sprintf("refs/remotes/%s/%s", remote.Name, currentBranch)) + } + + // Ignoring the error in this case is allowed because we may get refs and an error (see: --verify flag above). + // Ideally there would be a typed error to allow us to distinguish between an execution error and some refs + // not existing. However, this is too much to take on in an already large refactor. + refs, _ := opts.GitClient.ShowRefs(context.Background(), refsToLookup) + if len(refs) > 1 { + headRef := refs[0] + var firstMatchingRef o.Option[git.RemoteTrackingRef] + // Loop over all the refs, trying to find one that matches the SHA of HEAD. + for _, r := range refs[1:] { + if r.Hash == headRef.Hash { + remoteTrackingRef, err := git.ParseRemoteTrackingRef(r.Name) + if err != nil { + return nil, err + } + + firstMatchingRef = o.Some(remoteTrackingRef) + break + } } - if len(pushableRepos) == 0 { - pushableRepos, err = api.RepoFindForks(client, baseRepo, 3) + // If we found a matching ref, then we don't need to push. + if ref, present := firstMatchingRef.Value(); present { + remote, err := remotes.FindByName(ref.Remote) if err != nil { return nil, err } - } - - currentLogin, err := api.CurrentLoginName(client, baseRepo.RepoHost()) - if err != nil { - return nil, err - } - hasOwnFork := false - var pushOptions []string - for _, r := range pushableRepos { - pushOptions = append(pushOptions, ghrepo.FullName(r)) - if r.RepoOwner() == currentLogin { - hasOwnFork = true + qualifiedHeadRef := shared.NewQualifiedHeadRefWithoutOwner(ref.Branch) + if baseRepo.RepoOwner() != remote.RepoOwner() { + qualifiedHeadRef = shared.NewQualifiedHeadRef(remote.RepoOwner(), ref.Branch) } - } - if !hasOwnFork { - pushOptions = append(pushOptions, "Create a fork of "+ghrepo.FullName(baseRepo)) + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRefs: baseRefs, + }), nil } - pushOptions = append(pushOptions, "Skip pushing the branch") - pushOptions = append(pushOptions, "Cancel") + } + + // If we haven't got a repo by now, and we can't prompt then it's game over. + if !opts.IO.CanPrompt() { + fmt.Fprintln(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag") + return nil, cmdutil.SilentError + } + + // Otherwise, hooray, prompting! - selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", headBranch), "", pushOptions) + // First, we're going to look at our remotes and decide whether there are any repos we can push to. + pushableRepos, err := resolvedRemotes.HeadRepos() + if err != nil { + return nil, err + } + + // If we couldn't find any pushable repos, then find forks of the base repo. + if len(pushableRepos) == 0 { + pushableRepos, err = api.RepoFindForks(client, baseRepo, 3) if err != nil { return nil, err } - - if selectedOption < len(pushableRepos) { - headRepo = pushableRepos[selectedOption] - if !ghrepo.IsSame(baseRepo, headRepo) { - headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch) - } - } else if pushOptions[selectedOption] == "Skip pushing the branch" { - isPushEnabled = false - } else if pushOptions[selectedOption] == "Cancel" { - return nil, cmdutil.CancelError - } else { - // "Create a fork of ..." - headBranchLabel = fmt.Sprintf("%s:%s", currentLogin, headBranch) - } } - if headRepo == nil && isPushEnabled && !opts.IO.CanPrompt() { - fmt.Fprintf(opts.IO.ErrOut, "aborted: you must first push the current branch to a remote, or use the --head flag") - return nil, cmdutil.SilentError + currentLogin, err := api.CurrentLoginName(client, baseRepo.RepoHost()) + if err != nil { + return nil, err } - baseBranch := opts.BaseBranch - if baseBranch == "" { - baseBranch = headBranchConfig.MergeBase - } - if baseBranch == "" { - baseBranch = baseRepo.DefaultBranchRef.Name + hasOwnFork := false + var pushOptions []string + for _, r := range pushableRepos { + pushOptions = append(pushOptions, ghrepo.FullName(r)) + if r.RepoOwner() == currentLogin { + hasOwnFork = true + } } - if headBranch == baseBranch && headRepo != nil && ghrepo.IsSame(baseRepo, headRepo) { - return nil, fmt.Errorf("must be on a branch named differently than %q", baseBranch) + + if !hasOwnFork { + pushOptions = append(pushOptions, fmt.Sprintf("Create a fork of %s", ghrepo.FullName(baseRepo))) } + pushOptions = append(pushOptions, "Skip pushing the branch") + pushOptions = append(pushOptions, "Cancel") - baseTrackingBranch := baseBranch - if baseRemote, err := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()); err == nil { - baseTrackingBranch = fmt.Sprintf("%s/%s", baseRemote.Name, baseBranch) + selectedOption, err := opts.Prompter.Select(fmt.Sprintf("Where should we push the '%s' branch?", currentBranch), "", pushOptions) + if err != nil { + return nil, err } - return &CreateContext{ - BaseRepo: baseRepo, - HeadRepo: headRepo, - BaseBranch: baseBranch, - BaseTrackingBranch: baseTrackingBranch, - HeadBranch: headBranch, - HeadBranchLabel: headBranchLabel, - HeadRemote: headRemote, - IsPushEnabled: isPushEnabled, - RepoContext: repoContext, - Client: client, - GitClient: gitClient, - }, nil + if selectedOption < len(pushableRepos) { + // A repository has been selected to push to. + return newCreateContext(pushableRefs{ + headRepo: pushableRepos[selectedOption], + headBranchName: currentBranch, + baseRefs: baseRefs, + }), nil + } else if pushOptions[selectedOption] == "Skip pushing the branch" { + // We're going to skip pushing the branch altogether, meaning, use whatever SHA is already pushed. + // It's not exactly clear what repo the user expects to use here for the HEAD, and maybe we should + // make that clear in the UX somehow, but in the old implementation as far as I can tell, this + // always meant "use the base repo". + return newCreateContext(skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner(currentBranch), + baseRefs: baseRefs, + }), nil + } else if pushOptions[selectedOption] == "Cancel" { + return nil, cmdutil.CancelError + } else { + // A fork should be created. + return newCreateContext(forkableRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRef(currentLogin, currentBranch), + baseRefs: baseRefs, + }), nil + } } func getRemotes(opts *CreateOptions) (ghContext.Remotes, error) { @@ -805,15 +970,15 @@ func getRemotes(opts *CreateOptions) (ghContext.Remotes, error) { return remotes, nil } -func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataState) error { +func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataState, projectV1Support gh.ProjectsV1Support) error { client := ctx.Client params := map[string]interface{}{ "title": state.Title, "body": state.Body, "draft": state.Draft, - "baseRefName": ctx.BaseBranch, - "headRefName": ctx.HeadBranchLabel, + "baseRefName": ctx.PRRefs.BaseRef(), + "headRefName": ctx.PRRefs.QualifiedHeadRef(), "maintainerCanModify": opts.MaintainerCanModify, } @@ -821,7 +986,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS return errors.New("pull request title must not be blank") } - err := shared.AddMetadataToIssueParams(client, ctx.BaseRepo, params, &state) + err := shared.AddMetadataToIssueParams(client, ctx.PRRefs.BaseRepo(), params, &state, projectV1Support) if err != nil { return err } @@ -835,7 +1000,7 @@ func submitPR(opts CreateOptions, ctx CreateContext, state shared.IssueMetadataS } opts.IO.StartProgressIndicator() - pr, err := api.CreatePullRequest(client, ctx.BaseRepo, params) + pr, err := api.CreatePullRequest(client, ctx.PRRefs.BaseRepo(), params) opts.IO.StopProgressIndicator() if pr != nil { fmt.Fprintln(opts.IO.Out, pr.URL) @@ -867,8 +1032,8 @@ func renderPullRequestPlain(w io.Writer, params map[string]interface{}, state *s if len(state.Milestones) != 0 { fmt.Fprintf(w, "milestones:\t%v\n", strings.Join(state.Milestones, ", ")) } - if len(state.Projects) != 0 { - fmt.Fprintf(w, "projects:\t%v\n", strings.Join(state.Projects, ", ")) + if len(state.ProjectTitles) != 0 { + fmt.Fprintf(w, "projects:\t%v\n", strings.Join(state.ProjectTitles, ", ")) } fmt.Fprintf(w, "maintainerCanModify:\t%t\n", params["maintainerCanModify"]) fmt.Fprint(w, "body:\n") @@ -899,8 +1064,8 @@ func renderPullRequestTTY(io *iostreams.IOStreams, params map[string]interface{} if len(state.Milestones) != 0 { fmt.Fprintf(out, "%s: %s\n", cs.Bold("Milestones"), strings.Join(state.Milestones, ", ")) } - if len(state.Projects) != 0 { - fmt.Fprintf(out, "%s: %s\n", cs.Bold("Projects"), strings.Join(state.Projects, ", ")) + if len(state.ProjectTitles) != 0 { + fmt.Fprintf(out, "%s: %s\n", cs.Bold("Projects"), strings.Join(state.ProjectTitles, ", ")) } fmt.Fprintf(out, "%s: %t\n", cs.Bold("MaintainerCanModify"), params["maintainerCanModify"]) @@ -931,38 +1096,43 @@ func previewPR(opts CreateOptions, openURL string) error { } func handlePush(opts CreateOptions, ctx CreateContext) error { - didForkRepo := false - headRepo := ctx.HeadRepo - headRemote := ctx.HeadRemote - client := ctx.Client - gitClient := ctx.GitClient - - var err error - // if a head repository could not be determined so far, automatically create - // one by forking the base repository - if headRepo == nil && ctx.IsPushEnabled { + refs := ctx.PRRefs + forkableRefs, requiresFork := refs.(forkableRefs) + if requiresFork { opts.IO.StartProgressIndicator() - headRepo, err = api.ForkRepo(client, ctx.BaseRepo, "", "", false) + forkedRepo, err := api.ForkRepo(ctx.Client, forkableRefs.BaseRepo(), "", "", false) opts.IO.StopProgressIndicator() if err != nil { return fmt.Errorf("error forking repo: %w", err) } - didForkRepo = true + + refs = pushableRefs{ + headRepo: forkedRepo, + headBranchName: forkableRefs.qualifiedHeadRef.BranchName(), + baseRefs: baseRefs{ + baseRepo: forkableRefs.baseRepo, + baseBranchName: forkableRefs.baseBranchName, + }, + } } - if headRemote == nil && headRepo != nil { - headRemote, _ = ctx.RepoContext.RemoteForRepo(headRepo) + // We may have upcast to pushableRefs on fork, or we may have been passed an instance + // already. But if we haven't, then there's nothing more to do. + pushableRefs, ok := refs.(pushableRefs) + if !ok { + return nil } // There are two cases when an existing remote for the head repo will be - // missing: + // missing (and an error will be returned): // 1. the head repo was just created by auto-forking; // 2. an existing fork was discovered by querying the API. // In either case, we want to add the head repo as a new git remote so we // can push to it. We will try to add the head repo as the "origin" remote // and fallback to the "fork" remote if it is unavailable. Also, if the // base repo is the "origin" remote we will rename it "upstream". - if headRemote == nil && ctx.IsPushEnabled { + headRemote, _ := ctx.ResolvedRemotes.RemoteForRepo(pushableRefs.HeadRepo()) + if headRemote == nil { cfg, err := opts.Config() if err != nil { return err @@ -973,8 +1143,8 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return err } - cloneProtocol := cfg.GitProtocol(headRepo.RepoHost()).Value - headRepoURL := ghrepo.FormatRemoteURL(headRepo, cloneProtocol) + cloneProtocol := cfg.GitProtocol(pushableRefs.HeadRepo().RepoHost()).Value + headRepoURL := ghrepo.FormatRemoteURL(pushableRefs.HeadRepo(), cloneProtocol) gitClient := ctx.GitClient origin, _ := remotes.FindByName("origin") upstreamName := "upstream" @@ -985,7 +1155,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { remoteName = "fork" } - if origin != nil && upstream == nil && ghrepo.IsSame(origin, ctx.BaseRepo) { + if origin != nil && upstream == nil && ghrepo.IsSame(origin, pushableRefs.BaseRepo()) { renameCmd, err := gitClient.Command(context.Background(), "remote", "rename", "origin", upstreamName) if err != nil { return err @@ -994,7 +1164,7 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return fmt.Errorf("error renaming origin remote: %w", err) } remoteName = "origin" - fmt.Fprintf(opts.IO.ErrOut, "Changed %s remote to %q\n", ghrepo.FullName(ctx.BaseRepo), upstreamName) + fmt.Fprintf(opts.IO.ErrOut, "Changed %s remote to %q\n", ghrepo.FullName(pushableRefs.BaseRepo()), upstreamName) } gitRemote, err := gitClient.AddRemote(context.Background(), remoteName, headRepoURL, []string{}) @@ -1002,10 +1172,10 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { return fmt.Errorf("error adding remote: %w", err) } - fmt.Fprintf(opts.IO.ErrOut, "Added %s as remote %q\n", ghrepo.FullName(headRepo), remoteName) + fmt.Fprintf(opts.IO.ErrOut, "Added %s as remote %q\n", ghrepo.FullName(pushableRefs.HeadRepo()), remoteName) // Only mark `upstream` remote as default if `gh pr create` created the remote. - if didForkRepo { + if requiresFork { err := gitClient.SetRemoteResolution(context.Background(), upstreamName, "base") if err != nil { return fmt.Errorf("error setting upstream as default: %w", err) @@ -1013,52 +1183,46 @@ func handlePush(opts CreateOptions, ctx CreateContext) error { if opts.IO.IsStdoutTTY() { cs := opts.IO.ColorScheme() - fmt.Fprintf(opts.IO.ErrOut, "%s Repository %s set as the default repository. To learn more about the default repository, run: gh repo set-default --help\n", cs.WarningIcon(), cs.Bold(ghrepo.FullName(headRepo))) + fmt.Fprintf(opts.IO.ErrOut, "%s Repository %s set as the default repository. To learn more about the default repository, run: gh repo set-default --help\n", cs.WarningIcon(), cs.Bold(ghrepo.FullName(pushableRefs.HeadRepo()))) } } headRemote = &ghContext.Remote{ Remote: gitRemote, - Repo: headRepo, + Repo: pushableRefs.HeadRepo(), } } // automatically push the branch if it hasn't been pushed anywhere yet - if ctx.IsPushEnabled { - pushBranch := func() error { - w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "") - defer w.Flush() - ref := fmt.Sprintf("HEAD:refs/heads/%s", ctx.HeadBranch) - bo := backoff.NewConstantBackOff(2 * time.Second) - ctx := context.Background() - return backoff.Retry(func() error { - if err := gitClient.Push(ctx, headRemote.Name, ref, git.WithStderr(w)); err != nil { - // Only retry if we have forked the repo else the push should succeed the first time. - if didForkRepo { - fmt.Fprintf(opts.IO.ErrOut, "waiting 2 seconds before retrying...\n") - return err - } - return backoff.Permanent(err) + pushBranch := func() error { + w := NewRegexpWriter(opts.IO.ErrOut, gitPushRegexp, "") + defer w.Flush() + ref := fmt.Sprintf("HEAD:refs/heads/%s", ctx.PRRefs.UnqualifiedHeadRef()) + bo := backoff.NewConstantBackOff(2 * time.Second) + root := context.Background() + return backoff.Retry(func() error { + if err := ctx.GitClient.Push(root, headRemote.Name, ref, git.WithStderr(w)); err != nil { + // Only retry if we have forked the repo else the push should succeed the first time. + if requiresFork { + fmt.Fprintf(opts.IO.ErrOut, "waiting 2 seconds before retrying...\n") + return err } - return nil - }, backoff.WithContext(backoff.WithMaxRetries(bo, 3), ctx)) - } - - err := pushBranch() - if err != nil { - return err - } + return backoff.Permanent(err) + } + return nil + }, backoff.WithContext(backoff.WithMaxRetries(bo, 3), root)) } - return nil + return pushBranch() } func generateCompareURL(ctx CreateContext, state shared.IssueMetadataState) (string, error) { u := ghrepo.GenerateRepoURL( - ctx.BaseRepo, + ctx.PRRefs.BaseRepo(), "compare/%s...%s?expand=1", - url.PathEscape(ctx.BaseBranch), url.PathEscape(ctx.HeadBranchLabel)) - url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.BaseRepo, u, state) + url.PathEscape(ctx.PRRefs.BaseRef()), url.PathEscape(ctx.PRRefs.QualifiedHeadRef())) + // TODO wm: revisit project support + url, err := shared.WithPrAndIssueQueryParams(ctx.Client, ctx.PRRefs.BaseRepo(), u, state, gh.ProjectsV1Supported) if err != nil { return "", err } diff --git a/pkg/cmd/pr/create/create_test.go b/pkg/cmd/pr/create/create_test.go index 55012d7ddd9..2a88b5eee13 100644 --- a/pkg/cmd/pr/create/create_test.go +++ b/pkg/cmd/pr/create/create_test.go @@ -2,7 +2,6 @@ package create import ( "encoding/json" - "errors" "fmt" "net/http" "os" @@ -607,7 +606,7 @@ func Test_createRun(t *testing.T) { `), }, { - name: "survey", + name: "select a specific branch to push to on prompt", tty: true, setup: func(opts *CreateOptions, t *testing.T) func() { opts.TitleProvided = true @@ -636,7 +635,9 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -651,6 +652,52 @@ func Test_createRun(t *testing.T) { expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n", }, + { + name: "skip pushing to branch on prompt", + tty: true, + setup: func(opts *CreateOptions, t *testing.T) func() { + opts.TitleProvided = true + opts.BodyProvided = true + opts.Title = "my title" + opts.Body = "my body" + return func() {} + }, + httpStubs: func(reg *httpmock.Registry, t *testing.T) { + reg.StubRepoResponse("OWNER", "REPO") + reg.Register( + httpmock.GraphQL(`query UserCurrent\b`), + httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`)) + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } }`, func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"].(string)) + assert.Equal(t, "my title", input["title"].(string)) + assert.Equal(t, "my body", input["body"].(string)) + assert.Equal(t, "master", input["baseRefName"].(string)) + assert.Equal(t, "feature", input["headRefName"].(string)) + assert.Equal(t, false, input["draft"].(bool)) + })) + }, + cmdStubs: func(cs *run.CommandStubber) { + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 1, "") + }, + promptStubs: func(pm *prompter.PrompterMock) { + pm.SelectFunc = func(p, _ string, opts []string) (int, error) { + if p == "Where should we push the 'feature' branch?" { + return prompter.IndexFor(opts, "Skip pushing the branch") + } else { + return -1, prompter.NoSuchPromptErr(p) + } + } + }, + expectedOut: "https://github.com/OWNER/REPO/pull/12\n", + expectedErrOut: "\nCreating pull request for feature into master in OWNER/REPO\n\n", + }, { name: "project v2", tty: true, @@ -699,7 +746,9 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -745,7 +794,9 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -794,7 +845,10 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 1, "") + cs.Register("git config remote.pushDefault", 1, "") + cs.Register("git config push.default", 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register("git remote rename origin upstream", 0, "") cs.Register(`git remote add origin https://github.com/monalisa/REPO.git`, 0, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") @@ -853,10 +907,10 @@ func Test_createRun(t *testing.T) { })) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register("git show-ref --verify", 0, heredoc.Doc(` - deadbeef HEAD - deadb00f refs/remotes/upstream/feature - deadbeef refs/remotes/origin/feature`)) // determineTrackingBranch + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/feature", 0, heredoc.Doc(` + deadbeef HEAD + deadbeef refs/remotes/origin/feature`)) }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for monalisa:feature into master in OWNER/REPO\n\n", @@ -889,11 +943,12 @@ func Test_createRun(t *testing.T) { cs.Register(`git config --get-regexp \^branch\\\.feature\\\.`, 0, heredoc.Doc(` branch.feature.remote origin branch.feature.merge refs/heads/my-feat2 - `)) // determineTrackingBranch - cs.Register("git show-ref --verify", 0, heredoc.Doc(` + `)) + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/my-feat2") + cs.Register("git show-ref --verify -- HEAD refs/remotes/origin/my-feat2", 0, heredoc.Doc(` deadbeef HEAD deadbeef refs/remotes/origin/my-feat2 - `)) // determineTrackingBranch + `)) }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for my-feat2 into master in OWNER/REPO\n\n", @@ -1073,8 +1128,10 @@ func Test_createRun(t *testing.T) { httpmock.StringResponse(`{"data": {"viewer": {"login": "OWNER"} } }`)) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -1105,8 +1162,10 @@ func Test_createRun(t *testing.T) { mockRetrieveProjects(t, reg) }, cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, "") cs.Register(`git( .+)? log( .+)? origin/master\.\.\.feature`, 0, "") + cs.Register("git rev-parse --symbolic-full-name feature@{push}", 0, "refs/remotes/origin/feature") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 1, "") cs.Register(`git push --set-upstream origin HEAD:refs/heads/feature`, 0, "") }, promptStubs: func(pm *prompter.PrompterMock) { @@ -1271,31 +1330,6 @@ func Test_createRun(t *testing.T) { }, wantErr: "cannot open in browser: maximum URL length exceeded", }, - { - name: "no local git repo", - setup: func(opts *CreateOptions, t *testing.T) func() { - opts.Title = "My PR" - opts.TitleProvided = true - opts.Body = "" - opts.BodyProvided = true - opts.HeadBranch = "feature" - opts.RepoOverride = "OWNER/REPO" - opts.Remotes = func() (context.Remotes, error) { - return nil, errors.New("not a git repository") - } - return func() {} - }, - httpStubs: func(reg *httpmock.Registry, t *testing.T) { - reg.Register( - httpmock.GraphQL(`mutation PullRequestCreate\b`), - httpmock.StringResponse(` - { "data": { "createPullRequest": { "pullRequest": { - "URL": "https://github.com/OWNER/REPO/pull/12" - } } } } - `)) - }, - expectedOut: "https://github.com/OWNER/REPO/pull/12\n", - }, { name: "single commit title and body are used", tty: true, @@ -1520,19 +1554,45 @@ func Test_createRun(t *testing.T) { branch.task1.remote origin branch.task1.merge refs/heads/task1 branch.task1.gh-merge-base feature/feat2`)) // ReadBranchConfig - cs.Register(`git show-ref --verify`, 0, heredoc.Doc(` + cs.Register("git rev-parse --symbolic-full-name task1@{push}", 0, "refs/remotes/origin/task1") + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/task1`, 0, heredoc.Doc(` deadbeef HEAD - deadb00f refs/remotes/upstream/feature/feat2 - deadbeef refs/remotes/origin/task1`)) // determineTrackingBranch + deadbeef refs/remotes/origin/task1`)) }, expectedOut: "https://github.com/OWNER/REPO/pull/12\n", expectedErrOut: "\nCreating pull request for monalisa:task1 into feature/feat2 in OWNER/REPO\n\n", }, + { + name: "--head contains : syntax", + httpStubs: func(reg *httpmock.Registry, t *testing.T) { + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } }`, + func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"]) + assert.Equal(t, "my title", input["title"]) + assert.Equal(t, "my body", input["body"]) + assert.Equal(t, "master", input["baseRefName"]) + assert.Equal(t, "otherowner:feature", input["headRefName"]) + })) + }, + setup: func(opts *CreateOptions, t *testing.T) func() { + opts.TitleProvided = true + opts.BodyProvided = true + opts.Title = "my title" + opts.Body = "my body" + opts.HeadBranch = "otherowner:feature" + return func() {} + }, + expectedOut: "https://github.com/OWNER/REPO/pull/12\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { branch := "feature" - reg := &httpmock.Registry{} reg.StubRepoInfoResponse("OWNER", "REPO", "master") defer reg.Verify(t) @@ -1548,7 +1608,7 @@ func Test_createRun(t *testing.T) { cs, cmdTeardown := run.Stub() defer cmdTeardown(t) - cs.Register(`git status --porcelain`, 0, "") + if !tt.customBranchConfig { cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") } @@ -1599,6 +1659,10 @@ func Test_createRun(t *testing.T) { } defer cleanSetup() + if opts.HeadBranch == "" { + cs.Register(`git status --porcelain`, 0, "") + } + err := createRun(&opts) output := &test.CmdOut{ OutBuf: stdout, @@ -1622,109 +1686,166 @@ func Test_createRun(t *testing.T) { } } -func Test_tryDetermineTrackingRef(t *testing.T) { - tests := []struct { - name string - cmdStubs func(*run.CommandStubber) - headBranchConfig git.BranchConfig - remotes context.Remotes - expectedTrackingRef trackingRef - expectedFound bool - }{ - { - name: "empty", - cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD`, 0, "abc HEAD") - }, - headBranchConfig: git.BranchConfig{}, - expectedTrackingRef: trackingRef{}, - expectedFound: false, +func TestRemoteGuessing(t *testing.T) { + // Given git config does not provide the necessary info to determine a remote + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) + + cs.Register(`git status --porcelain`, 0, "") + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register(`git rev-parse --symbolic-full-name feature@{push}`, 1, "") + cs.Register("git config remote.pushDefault", 1, "") + cs.Register("git config push.default", 1, "") + + // And Given there is a remote on a SHA that matches the current HEAD + cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature`, 0, heredoc.Doc(` + deadbeef HEAD + deadb00f refs/remotes/upstream/feature + deadbeef refs/remotes/origin/feature`)) + + // When the command is run + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "master") + defer reg.Verify(t) + + reg.Register( + httpmock.GraphQL(`mutation PullRequestCreate\b`), + httpmock.GraphQLMutation(` + { "data": { "createPullRequest": { "pullRequest": { + "URL": "https://github.com/OWNER/REPO/pull/12" + } } } }`, func(input map[string]interface{}) { + assert.Equal(t, "REPOID", input["repositoryId"].(string)) + assert.Equal(t, "master", input["baseRefName"].(string)) + assert.Equal(t, "OTHEROWNER:feature", input["headRefName"].(string)) + })) + + ios, _, _, _ := iostreams.Test() + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil }, - { - name: "no match", - cmdStubs: func(cs *run.CommandStubber) { - cs.Register("git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature", 0, "abc HEAD\nbca refs/remotes/upstream/feature") - }, - headBranchConfig: git.BranchConfig{}, - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.New("octocat", "Spoon-Knife"), - }, - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), - }, - }, - expectedTrackingRef: trackingRef{}, - expectedFound: false, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil }, - { - name: "match", - cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/upstream/feature refs/remotes/origin/feature$`, 0, heredoc.Doc(` - deadbeef HEAD - deadb00f refs/remotes/upstream/feature - deadbeef refs/remotes/origin/feature - `)) - }, - headBranchConfig: git.BranchConfig{}, - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "upstream"}, - Repo: ghrepo.New("octocat", "Spoon-Knife"), + Browser: &browser.Stub{}, + IO: ios, + Prompter: &prompter.PrompterMock{}, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "upstream", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), }, - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), + { + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("OTHEROWNER", "REPO-FORK"), }, - }, - expectedTrackingRef: trackingRef{ - remoteName: "origin", - branchName: "feature", - }, - expectedFound: true, + }, nil }, - { - name: "respect tracking config", - cmdStubs: func(cs *run.CommandStubber) { - cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/great-feat refs/remotes/origin/feature$`, 0, heredoc.Doc(` - deadbeef HEAD - deadb00f refs/remotes/origin/feature - `)) - }, - headBranchConfig: git.BranchConfig{ - RemoteName: "origin", - MergeRef: "refs/heads/great-feat", - }, - remotes: context.Remotes{ - &context.Remote{ - Remote: &git.Remote{Name: "origin"}, - Repo: ghrepo.New("hubot", "Spoon-Knife"), - }, - }, - expectedTrackingRef: trackingRef{}, - expectedFound: false, + Branch: func() (string, error) { + return "feature", nil }, + + TitleProvided: true, + BodyProvided: true, + Title: "my title", + Body: "my body", } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cs, cmdTeardown := run.Stub() - defer cmdTeardown(t) - tt.cmdStubs(cs) + require.NoError(t, createRun(&opts)) - gitClient := &git.Client{ - GhPath: "some/path/gh", - GitPath: "some/path/git", - } + // Then guessed remote is used for the PR head, + // which annoyingly, is asserted above on the line: + // assert.Equal(t, "OTHEROWNER:feature", input["headRefName"].(string)) + // + // This is because OTHEROWNER relates to the "origin" remote, which has a + // SHA that matches the HEAD ref in the `git show-ref` output. +} - ref, found := tryDetermineTrackingRef(gitClient, tt.remotes, "feature", tt.headBranchConfig) +func TestNoRepoCanBeDetermined(t *testing.T) { + // Given no head repo can be determined from git config + cs, cmdTeardown := run.Stub() + defer cmdTeardown(t) - assert.Equal(t, tt.expectedTrackingRef, ref) - assert.Equal(t, tt.expectedFound, found) - }) + cs.Register(`git status --porcelain`, 0, "") + cs.Register(`git config --get-regexp \^branch\\\..+\\\.\(remote\|merge\|pushremote\|gh-merge-base\)\$`, 0, "") + cs.Register(`git rev-parse --symbolic-full-name feature@{push}`, 1, "") + cs.Register("git config remote.pushDefault", 1, "") + cs.Register("git config push.default", 1, "") + + // And Given there is no remote on the correct SHA + cs.Register(`git show-ref --verify -- HEAD refs/remotes/origin/feature`, 0, heredoc.Doc(` + deadbeef HEAD + deadb00f refs/remotes/origin/feature`)) + + // When the command is run with no TTY + reg := &httpmock.Registry{} + reg.StubRepoInfoResponse("OWNER", "REPO", "master") + defer reg.Verify(t) + + ios, _, _, stderr := iostreams.Test() + + opts := CreateOptions{ + HttpClient: func() (*http.Client, error) { + return &http.Client{Transport: reg}, nil + }, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + IO: ios, + Prompter: &prompter.PrompterMock{}, + GitClient: &git.Client{ + GhPath: "some/path/gh", + GitPath: "some/path/git", + }, + Finder: shared.NewMockFinder("feature", nil, nil), + Remotes: func() (context.Remotes, error) { + return context.Remotes{ + { + Remote: &git.Remote{ + Name: "origin", + Resolved: "base", + }, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, nil + }, + Branch: func() (string, error) { + return "feature", nil + }, + + TitleProvided: true, + BodyProvided: true, + Title: "my title", + Body: "my body", + } + + // When we run the command + err := createRun(&opts) + + // Then create fails + require.Equal(t, cmdutil.SilentError, err) + assert.Equal(t, "aborted: you must first push the current branch to a remote, or use the --head flag\n", stderr.String()) +} + +func mustParseQualifiedHeadRef(ref string) shared.QualifiedHeadRef { + parsed, err := shared.ParseQualifiedHeadRef(ref) + if err != nil { + panic(err) } + return parsed } func Test_generateCompareURL(t *testing.T) { @@ -1738,9 +1859,13 @@ func Test_generateCompareURL(t *testing.T) { { name: "basic", ctx: CreateContext{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BaseBranch: "main", - HeadBranchLabel: "feature", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, + }, }, want: "https://github.com/OWNER/REPO/compare/main...feature?body=&expand=1", wantErr: false, @@ -1748,9 +1873,13 @@ func Test_generateCompareURL(t *testing.T) { { name: "with labels", ctx: CreateContext{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BaseBranch: "a", - HeadBranchLabel: "b", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("b"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "a", + }, + }, }, state: shared.IssueMetadataState{ Labels: []string{"one", "two three"}, @@ -1761,35 +1890,47 @@ func Test_generateCompareURL(t *testing.T) { { name: "'/'s in branch names/labels are percent-encoded", ctx: CreateContext{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BaseBranch: "main/trunk", - HeadBranchLabel: "owner:feature", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("ORIGINOWNER:feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"), + baseBranchName: "main/trunk", + }, + }, }, - want: "https://github.com/OWNER/REPO/compare/main%2Ftrunk...owner:feature?body=&expand=1", + want: "https://github.com/UPSTREAMOWNER/REPO/compare/main%2Ftrunk...ORIGINOWNER:feature?body=&expand=1", wantErr: false, }, { name: "Any of !'(),; but none of $&+=@ and : in branch names/labels are percent-encoded ", /* - - Technically, per section 3.3 of RFC 3986, none of !$&'()*+,;= (sub-delims) and :[]@ (part of gen-delims) in path segments are optionally percent-encoded, but url.PathEscape percent-encodes !'(),; anyway - - !$&'()+,;=@ is a valid Git branch name—essentially RFC 3986 sub-delims without * and gen-delims without :/?#[] - - : is GitHub separator between a fork name and a branch name - - See https://github.com/golang/go/issues/27559. + - Technically, per section 3.3 of RFC 3986, none of !$&'()*+,;= (sub-delims) and :[]@ (part of gen-delims) in path segments are optionally percent-encoded, but url.PathEscape percent-encodes !'(),; anyway + - !$&'()+,;=@ is a valid Git branch name—essentially RFC 3986 sub-delims without * and gen-delims without :/?#[] + - : is GitHub separator between a fork name and a branch name + - See https://github.com/golang/go/issues/27559. */ ctx: CreateContext{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BaseBranch: "main/trunk", - HeadBranchLabel: "owner:!$&'()+,;=@", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("ORIGINOWNER:!$&'()+,;=@"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "UPSTREAMOWNER"}}, "github.com"), + baseBranchName: "main/trunk", + }, + }, }, - want: "https://github.com/OWNER/REPO/compare/main%2Ftrunk...owner:%21$&%27%28%29+%2C%3B=@?body=&expand=1", + want: "https://github.com/UPSTREAMOWNER/REPO/compare/main%2Ftrunk...ORIGINOWNER:%21$&%27%28%29+%2C%3B=@?body=&expand=1", wantErr: false, }, { name: "with template", ctx: CreateContext{ - BaseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), - BaseBranch: "main", - HeadBranchLabel: "feature", + PRRefs: &skipPushRefs{ + qualifiedHeadRef: shared.NewQualifiedHeadRefWithoutOwner("feature"), + baseRefs: baseRefs{ + baseRepo: api.InitRepoHostname(&api.Repository{Name: "REPO", Owner: api.RepositoryOwner{Login: "OWNER"}}, "github.com"), + baseBranchName: "main", + }, + }, }, state: shared.IssueMetadataState{ Template: "story.md", diff --git a/pkg/cmd/pr/shared/editable.go b/pkg/cmd/pr/shared/editable.go index cec3bfe8c9c..0bebb999ac0 100644 --- a/pkg/cmd/pr/shared/editable.go +++ b/pkg/cmd/pr/shared/editable.go @@ -381,7 +381,8 @@ func FetchOptions(client *api.Client, repo ghrepo.Interface, editable *Editable) Reviewers: editable.Reviewers.Edited, Assignees: editable.Assignees.Edited, Labels: editable.Labels.Edited, - Projects: editable.Projects.Edited, + ProjectsV1: editable.Projects.Edited, + ProjectsV2: editable.Projects.Edited, Milestones: editable.Milestone.Edited, } metadata, err := api.RepoMetadata(client, repo, input) diff --git a/pkg/cmd/pr/shared/find_refs_resolution.go b/pkg/cmd/pr/shared/find_refs_resolution.go new file mode 100644 index 00000000000..833075af818 --- /dev/null +++ b/pkg/cmd/pr/shared/find_refs_resolution.go @@ -0,0 +1,394 @@ +package shared + +import ( + "context" + "fmt" + "net/url" + "strings" + + ghContext "github.com/cli/cli/v2/context" + "github.com/cli/cli/v2/git" + + "github.com/cli/cli/v2/internal/ghrepo" + o "github.com/cli/cli/v2/pkg/option" +) + +// QualifiedHeadRef represents a git branch with an optional owner, used +// for the head of a pull request. For example, within a single repository, +// we would expect a PR to have a head ref of no owner, and a branch name. +// However, for cross-repository pull requests, we would expect a head ref +// with an owner and a branch name. In string form this is represented as +// :. The GitHub API is able to interpret this format in order +// to discover the correct fork repository. +// +// In other parts of the code, you may see this refered to as a HeadLabel. +type QualifiedHeadRef struct { + owner o.Option[string] + branchName string +} + +// NewQualifiedHeadRef creates a QualifiedHeadRef. If the empty string is provided +// for the owner, it will be treated as None. +func NewQualifiedHeadRef(owner string, branchName string) QualifiedHeadRef { + return QualifiedHeadRef{ + owner: o.SomeIfNonZero(owner), + branchName: branchName, + } +} + +func NewQualifiedHeadRefWithoutOwner(branchName string) QualifiedHeadRef { + return QualifiedHeadRef{ + owner: o.None[string](), + branchName: branchName, + } +} + +// ParseQualifiedHeadRef takes strings of the form : or +// and returns a QualifiedHeadRef. If the form : is used, +// the owner is set to the value of , and the branch name is set to +// the value of . If the form is used, the owner is set to +// None, and the branch name is set to the value of . +// +// This does no further error checking about the validity of a ref, so +// it is not safe to assume the ref is truly a valid ref, e.g. "my~bad:ref?" +// is going to result in a nonsense result. +func ParseQualifiedHeadRef(ref string) (QualifiedHeadRef, error) { + if !strings.Contains(ref, ":") { + return NewQualifiedHeadRefWithoutOwner(ref), nil + } + + parts := strings.Split(ref, ":") + if len(parts) != 2 { + return QualifiedHeadRef{}, fmt.Errorf("invalid qualified head ref format '%s'", ref) + } + + return NewQualifiedHeadRef(parts[0], parts[1]), nil +} + +// A QualifiedHeadRef without an owner returns , while a QualifiedHeadRef +// with an owner returns :. +func (r QualifiedHeadRef) String() string { + if owner, present := r.owner.Value(); present { + return fmt.Sprintf("%s:%s", owner, r.branchName) + } + return r.branchName +} + +func (r QualifiedHeadRef) BranchName() string { + return r.branchName +} + +// PRFindRefs represents the necessary data to find a pull request from the API. +type PRFindRefs struct { + qualifiedHeadRef QualifiedHeadRef + + baseRepo ghrepo.Interface + // baseBranchName is an optional branch name, because it is not required for + // finding a pull request, only for disambiguation if multiple pull requests + // contain the same head ref. + baseBranchName o.Option[string] +} + +// QualifiedHeadRef returns a stringified form of the head ref, varying depending +// on whether the head ref is in the same repository as the base ref. If they are +// the same repository, we return the branch name only. If they are different repositories, +// we return the owner and branch name in the form :. +func (r PRFindRefs) QualifiedHeadRef() string { + return r.qualifiedHeadRef.String() +} + +func (r PRFindRefs) UnqualifiedHeadRef() string { + return r.qualifiedHeadRef.BranchName() +} + +// Matches checks whether the provided baseBranchName and headRef match the refs. +// It is used to determine whether Pull Requests returned from the API +func (r PRFindRefs) Matches(baseBranchName, qualifiedHeadRef string) bool { + headMatches := qualifiedHeadRef == r.QualifiedHeadRef() + baseMatches := r.baseBranchName.IsNone() || baseBranchName == r.baseBranchName.Unwrap() + return headMatches && baseMatches +} + +func (r PRFindRefs) BaseRepo() ghrepo.Interface { + return r.baseRepo +} + +type RemoteNameToRepoFn func(remoteName string) (ghrepo.Interface, error) + +// PullRequestFindRefsResolver interrogates git configuration to try and determine +// a head repository and a remote branch name, from a local branch name. +type PullRequestFindRefsResolver struct { + GitConfigClient GitConfigClient + RemoteNameToRepoFn RemoteNameToRepoFn +} + +func NewPullRequestFindRefsResolver(gitConfigClient GitConfigClient, remotesFn func() (ghContext.Remotes, error)) PullRequestFindRefsResolver { + return PullRequestFindRefsResolver{ + GitConfigClient: gitConfigClient, + RemoteNameToRepoFn: newRemoteNameToRepoFn(remotesFn), + } +} + +// ResolvePullRequests takes a base repository, a base branch name and a local branch name and uses the git configuration to +// determine the head repository and remote branch name. If we were unable to determine this from git, we default the head +// repository to the base repository. +func (r *PullRequestFindRefsResolver) ResolvePullRequestRefs(baseRepo ghrepo.Interface, baseBranchName, localBranchName string) (PRFindRefs, error) { + if baseRepo == nil { + return PRFindRefs{}, fmt.Errorf("find pull request ref resolution cannot be performed without a base repository") + } + + if localBranchName == "" { + return PRFindRefs{}, fmt.Errorf("find pull request ref resolution cannot be performed without a local branch name") + } + + headPRRef, err := TryDetermineDefaultPRHead(r.GitConfigClient, remoteToRepoResolver{r.RemoteNameToRepoFn}, localBranchName) + if err != nil { + return PRFindRefs{}, err + } + + // If the headRepo was resolved, we can just convert the response + // to refs and return it. + if headRepo, present := headPRRef.Repo.Value(); present { + qualifiedHeadRef := NewQualifiedHeadRefWithoutOwner(headPRRef.BranchName) + if !ghrepo.IsSame(headRepo, baseRepo) { + qualifiedHeadRef = NewQualifiedHeadRef(headRepo.RepoOwner(), headPRRef.BranchName) + } + + return PRFindRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRepo: baseRepo, + baseBranchName: o.SomeIfNonZero(baseBranchName), + }, nil + } + + // If we didn't find a head repo, default to the base repo + return PRFindRefs{ + qualifiedHeadRef: NewQualifiedHeadRefWithoutOwner(headPRRef.BranchName), + baseRepo: baseRepo, + baseBranchName: o.SomeIfNonZero(baseBranchName), + }, nil +} + +// DefaultPRHead is a neighbour to defaultPushTarget, but instead of holding +// basic git remote information, it holds a resolved repository in `gh` terms. +// +// Since we may not be able to determine a default remote for a branch, this +// is also true of the resolved repository. +type DefaultPRHead struct { + Repo o.Option[ghrepo.Interface] + BranchName string +} + +// TryDetermineDefaultPRHead is a thin wrapper around determineDefaultPushTarget, which attempts to convert +// a present remote into a resolved repository. If the remote is not present, we indicate that to the caller +// by returning a None value for the repo. +func TryDetermineDefaultPRHead(gitClient GitConfigClient, remoteToRepo remoteToRepoResolver, branch string) (DefaultPRHead, error) { + pushTarget, err := tryDetermineDefaultPushTarget(gitClient, branch) + if err != nil { + return DefaultPRHead{}, err + } + + // If we have no remote, let the caller decide what to do by indicating that with a None. + if pushTarget.remote.IsNone() { + return DefaultPRHead{ + Repo: o.None[ghrepo.Interface](), + BranchName: pushTarget.branchName, + }, nil + } + + repo, err := remoteToRepo.resolve(pushTarget.remote.Unwrap()) + if err != nil { + return DefaultPRHead{}, err + } + + return DefaultPRHead{ + Repo: o.Some(repo), + BranchName: pushTarget.branchName, + }, nil +} + +// remote represents the value of the remote key in a branch's git configuration. +// This value may be a name or a URL, both of which are strings, but are unfortunately +// parsed by ReadBranchConfig into separate fields, allowing for illegal states to be +// created by accident. This is an attempt to indicate that they are mutally exclusive. +type remote interface{ sealedRemote() } + +type remoteName struct{ name string } + +func (rn remoteName) sealedRemote() {} + +type remoteURL struct{ url *url.URL } + +func (ru remoteURL) sealedRemote() {} + +// newRemoteNameToRepoFn takes a function that returns a list of remotes and +// returns a function that takes a remote name and returns the corresponding +// repository. It is a convenience function to call sites having to duplicate +// the same logic. +func newRemoteNameToRepoFn(remotesFn func() (ghContext.Remotes, error)) RemoteNameToRepoFn { + return func(remoteName string) (ghrepo.Interface, error) { + remotes, err := remotesFn() + if err != nil { + return nil, err + } + repo, err := remotes.FindByName(remoteName) + if err != nil { + return nil, err + } + return repo, nil + } +} + +// remoteToRepoResolver provides a utility method to resolve a remote (either name or URL) +// to a repo (ghrepo.Interface). +type remoteToRepoResolver struct { + remoteNameToRepo RemoteNameToRepoFn +} + +func NewRemoteToRepoResolver(remotesFn func() (ghContext.Remotes, error)) remoteToRepoResolver { + return remoteToRepoResolver{ + remoteNameToRepo: newRemoteNameToRepoFn(remotesFn), + } +} + +// resolve takes a remote and returns a repository representing it. +func (r remoteToRepoResolver) resolve(remote remote) (ghrepo.Interface, error) { + switch v := remote.(type) { + case remoteName: + repo, err := r.remoteNameToRepo(v.name) + if err != nil { + return nil, fmt.Errorf("could not resolve remote %q: %w", v.name, err) + } + return repo, nil + case remoteURL: + repo, err := ghrepo.FromURL(v.url) + if err != nil { + return nil, fmt.Errorf("could not parse remote URL %q: %w", v.url, err) + } + return repo, nil + default: + return nil, fmt.Errorf("unsupported remote type %T, value: %v", v, remote) + } +} + +// A defaultPushTarget represents the remote name or URL and a branch name +// that we would expect a branch to be pushed to if `git push` were run with +// no further arguments. This is the most likely place for the head of the PR +// to be, but it's not guaranteed. The user may have pushed to another branch +// directly via `git push :` and not set up tracking information. +// A branch name is always present. +// +// It's possible that we're unable to determine a remote, if the user had pushed directly +// to a URL for example `git push `, which is why it is optional. When present, +// the remote may either be a name or a URL. +type defaultPushTarget struct { + remote o.Option[remote] + branchName string +} + +// newDefaultPushTarget is a thin wrapper over defaultPushTarget to help with +// generic type inference, to reduce verbosity in repeating the parametric type. +func newDefaultPushTarget(remote remote, branchName string) defaultPushTarget { + return defaultPushTarget{ + remote: o.Some(remote), + branchName: branchName, + } +} + +// tryDetermineDefaultPushTarget uses git configuration to make a best guess about where a branch +// is pushed to, and where it would be pushed to if the user ran `git push` with no additional +// arguments. +// +// Firstly, it attempts to resolve the @{push} ref, which is the most reliable method, as this +// is what git uses to determine the remote tracking branch +// +// If this fails, we go through a series of steps to determine the remote: +// +// 1. check branch configuration for `branch..pushRemote = | ` +// 2. check remote configuration for `remote.pushDefault = ` +// 3. check branch configuration for `branch..remote = | ` +// +// If none of these are set, we indicate that we were unable to determine the +// remote by returning a None value for the remote. +// +// The branch name is always set. The default configuration for push.default (current) indicates +// that a git push should use the same remote branch name as the local branch name. If push.default +// is set to upstream or tracking (deprecated form of upstream), then we use the branch name from the merge ref. +func tryDetermineDefaultPushTarget(gitClient GitConfigClient, localBranchName string) (defaultPushTarget, error) { + // If @{push} resolves, then we have the remote tracking branch already, no problem. + if pushRevisionRef, err := gitClient.PushRevision(context.Background(), localBranchName); err == nil { + return newDefaultPushTarget(remoteName{pushRevisionRef.Remote}, pushRevisionRef.Branch), nil + } + + // But it doesn't always resolve, so we can suppress the error and move on to other means + // of determination. We'll first look at branch and remote configuration to make a determination. + branchConfig, err := gitClient.ReadBranchConfig(context.Background(), localBranchName) + if err != nil { + return defaultPushTarget{}, err + } + + pushDefault, err := gitClient.PushDefault(context.Background()) + if err != nil { + return defaultPushTarget{}, err + } + + // We assume the PR's branch name is the same as whatever was provided, unless the user has specified + // push.default = upstream or tracking, then we use the branch name from the merge ref. + remoteBranch := localBranchName + if pushDefault == git.PushDefaultUpstream || pushDefault == git.PushDefaultTracking { + remoteBranch = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") + if remoteBranch == "" { + return defaultPushTarget{}, fmt.Errorf("could not determine remote branch name") + } + } + + // To get the remote, we look to the git config. It comes from one of the following, in order of precedence: + // 1. branch..pushRemote (which may be a name or a URL) + // 2. remote.pushDefault (which is a remote name) + // 3. branch..remote (which may be a name or a URL) + if branchConfig.PushRemoteName != "" { + return newDefaultPushTarget( + remoteName{branchConfig.PushRemoteName}, + remoteBranch, + ), nil + } + + if branchConfig.PushRemoteURL != nil { + return newDefaultPushTarget( + remoteURL{branchConfig.PushRemoteURL}, + remoteBranch, + ), nil + } + + remotePushDefault, err := gitClient.RemotePushDefault(context.Background()) + if err != nil { + return defaultPushTarget{}, err + } + + if remotePushDefault != "" { + return newDefaultPushTarget( + remoteName{remotePushDefault}, + remoteBranch, + ), nil + } + + if branchConfig.RemoteName != "" { + return newDefaultPushTarget( + remoteName{branchConfig.RemoteName}, + remoteBranch, + ), nil + } + + if branchConfig.RemoteURL != nil { + return newDefaultPushTarget( + remoteURL{branchConfig.RemoteURL}, + remoteBranch, + ), nil + } + + // If we couldn't find the remote, we'll indicate that to the caller via None. + return defaultPushTarget{ + remote: o.None[remote](), + branchName: remoteBranch, + }, nil +} diff --git a/pkg/cmd/pr/shared/find_refs_resolution_test.go b/pkg/cmd/pr/shared/find_refs_resolution_test.go new file mode 100644 index 00000000000..8cbb62146e3 --- /dev/null +++ b/pkg/cmd/pr/shared/find_refs_resolution_test.go @@ -0,0 +1,508 @@ +package shared + +import ( + "errors" + "net/url" + "testing" + + ghContext "github.com/cli/cli/v2/context" + "github.com/cli/cli/v2/git" + "github.com/cli/cli/v2/internal/ghrepo" + o "github.com/cli/cli/v2/pkg/option" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQualifiedHeadRef(t *testing.T) { + t.Parallel() + + testCases := []struct { + behavior string + ref string + expectedString string + expectedBranchName string + expectedError error + }{ + { + behavior: "when a branch is provided, the parsed qualified head ref only has a branch", + ref: "feature-branch", + expectedString: "feature-branch", + expectedBranchName: "feature-branch", + }, + { + behavior: "when an owner and branch are provided, the parsed qualified head ref has both", + ref: "owner:feature-branch", + expectedString: "owner:feature-branch", + expectedBranchName: "feature-branch", + }, + { + behavior: "when the structure cannot be interpreted correctly, an error is returned", + ref: "owner:feature-branch:extra", + expectedError: errors.New("invalid qualified head ref format 'owner:feature-branch:extra'"), + }, + } + + for _, tc := range testCases { + t.Run(tc.behavior, func(t *testing.T) { + t.Parallel() + + qualifiedHeadRef, err := ParseQualifiedHeadRef(tc.ref) + if tc.expectedError != nil { + require.Equal(t, tc.expectedError, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tc.expectedString, qualifiedHeadRef.String()) + assert.Equal(t, tc.expectedBranchName, qualifiedHeadRef.BranchName()) + }) + } +} + +func TestPRFindRefs(t *testing.T) { + t.Parallel() + + t.Run("qualified head ref with owner", func(t *testing.T) { + t.Parallel() + + refs := PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("forkowner:feature-branch"), + } + + require.Equal(t, "forkowner:feature-branch", refs.QualifiedHeadRef()) + require.Equal(t, "feature-branch", refs.UnqualifiedHeadRef()) + }) + + t.Run("qualified head ref without owner", func(t *testing.T) { + t.Parallel() + + refs := PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + } + + require.Equal(t, "feature-branch", refs.QualifiedHeadRef()) + require.Equal(t, "feature-branch", refs.UnqualifiedHeadRef()) + }) + + t.Run("base repo", func(t *testing.T) { + t.Parallel() + + refs := PRFindRefs{ + baseRepo: ghrepo.New("owner", "repo"), + } + + require.True(t, ghrepo.IsSame(refs.BaseRepo(), ghrepo.New("owner", "repo")), "expected repos to be the same") + }) + + t.Run("matches", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + behavior string + refs PRFindRefs + baseBranchName string + qualifiedHeadRef string + expectedMatch bool + }{ + { + behavior: "when qualified head refs don't match, returns false", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("owner:feature-branch"), + }, + baseBranchName: "feature-branch", + qualifiedHeadRef: "feature-branch", + expectedMatch: false, + }, + { + behavior: "when base branches don't match, returns false", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + baseBranchName: o.Some("not-main"), + }, + baseBranchName: "main", + qualifiedHeadRef: "feature-branch", + expectedMatch: false, + }, + { + behavior: "when head refs match and there is no base branch, returns true", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + baseBranchName: o.None[string](), + }, + baseBranchName: "main", + qualifiedHeadRef: "feature-branch", + expectedMatch: true, + }, + { + behavior: "when head refs match and base branches match, returns true", + refs: PRFindRefs{ + qualifiedHeadRef: mustParseQualifiedHeadRef("feature-branch"), + baseBranchName: o.Some("main"), + }, + baseBranchName: "main", + qualifiedHeadRef: "feature-branch", + expectedMatch: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.behavior, func(t *testing.T) { + t.Parallel() + + require.Equal(t, tc.expectedMatch, tc.refs.Matches(tc.baseBranchName, tc.qualifiedHeadRef)) + }) + } + }) +} + +func TestPullRequestResolution(t *testing.T) { + t.Parallel() + + baseRepo := ghrepo.New("owner", "repo") + baseRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: ghrepo.New("owner", "repo"), + } + + forkRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: ghrepo.New("otherowner", "repo-fork"), + } + + t.Run("when the base repo is nil, returns an error", func(t *testing.T) { + t.Parallel() + + resolver := NewPullRequestFindRefsResolver(stubGitConfigClient{}, dummyRemotesFn) + _, err := resolver.ResolvePullRequestRefs(nil, "", "") + require.Error(t, err) + }) + + t.Run("when the local branch name is empty, returns an error", func(t *testing.T) { + t.Parallel() + + resolver := NewPullRequestFindRefsResolver(stubGitConfigClient{}, dummyRemotesFn) + _, err := resolver.ResolvePullRequestRefs(baseRepo, "", "") + require.Error(t, err) + }) + + t.Run("when the default pr head has a repo, it is used for the refs", func(t *testing.T) { + t.Parallel() + + // Push revision is the first thing checked for resolution, + // so nothing else needs to be stubbed. + repoResolvedFromPushRevisionClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{ + Remote: "origin", + Branch: "feature-branch", + }, nil), + } + + resolver := NewPullRequestFindRefsResolver( + repoResolvedFromPushRevisionClient, + stubRemotes(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + ) + + refs, err := resolver.ResolvePullRequestRefs(baseRepo, "main", "feature-branch") + require.NoError(t, err) + + expectedRefs := PRFindRefs{ + qualifiedHeadRef: QualifiedHeadRef{ + owner: o.Some("otherowner"), + branchName: "feature-branch", + }, + baseRepo: baseRepo, + baseBranchName: o.Some("main"), + } + + require.Equal(t, expectedRefs, refs) + }) + + t.Run("when the default pr head does not have a repo, we use the base repo for the head", func(t *testing.T) { + t.Parallel() + + // All the values stubbed here result in being unable to resolve a default repo. + noRepoResolutionStubClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault("", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + resolver := NewPullRequestFindRefsResolver( + noRepoResolutionStubClient, + stubRemotes(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + ) + + refs, err := resolver.ResolvePullRequestRefs(baseRepo, "main", "feature-branch") + require.NoError(t, err) + + expectedRefs := PRFindRefs{ + qualifiedHeadRef: QualifiedHeadRef{ + owner: o.None[string](), + branchName: "feature-branch", + }, + baseRepo: baseRepo, + baseBranchName: o.Some("main"), + } + require.Equal(t, expectedRefs, refs) + }) +} + +func TestTryDetermineDefaultPRHead(t *testing.T) { + t.Parallel() + + baseRepo := ghrepo.New("owner", "repo") + baseRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "upstream", + }, + Repo: baseRepo, + } + + forkRepo := ghrepo.New("otherowner", "repo-fork") + forkRemote := ghContext.Remote{ + Remote: &git.Remote{ + Name: "origin", + }, + Repo: forkRepo, + } + forkRepoURL, err := url.Parse("https://github.com/otherowner/repo-fork.git") + require.NoError(t, err) + + t.Run("when the push revision is set, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRevisionClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{ + Remote: "origin", + Branch: "remote-feature-branch", + }, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRevisionClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "remote-feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config push remote is set to a name, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: "origin", + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config push remote is set to a URL, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteURL: forkRepoURL, + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + dummyRemoteToRepoResolver(), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when a remote push default is set, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + remotePushDefaultFn: stubRemotePushDefault("origin", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config remote is set to a name, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + RemoteName: "origin", + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the branch config remote is set to a URL, use that", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("no push revision")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + RemoteURL: forkRepoURL, + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultCurrent, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + dummyRemoteToRepoResolver(), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when git didn't provide the necessary information, return none for the remote", func(t *testing.T) { + t.Parallel() + + // All the values stubbed here result in being unable to resolve a default repo. + noRepoResolutionStubClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault("", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + noRepoResolutionStubClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, defaultPRHead.Repo.IsNone(), "expected repo to be none") + require.Equal(t, "feature-branch", defaultPRHead.BranchName) + }) + + t.Run("when the push default is tracking or upstream, use the merge ref", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + pushDefault git.PushDefault + }{ + {pushDefault: git.PushDefaultTracking}, + {pushDefault: git.PushDefaultUpstream}, + } + + for _, tc := range testCases { + t.Run(string(tc.pushDefault), func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: "origin", + MergeRef: "main", + }, nil), + pushDefaultFn: stubPushDefault(tc.pushDefault, nil), + } + + defaultPRHead, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.NoError(t, err) + + require.True(t, ghrepo.IsSame(defaultPRHead.Repo.Unwrap(), forkRepo), "expected repos to be the same") + require.Equal(t, "main", defaultPRHead.BranchName) + }) + } + + t.Run("but if the merge ref is empty, error", func(t *testing.T) { + t.Parallel() + + repoResolvedFromPushRemoteClient := stubGitConfigClient{ + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("test error")), + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: "origin", + MergeRef: "", // intentionally empty + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultUpstream, nil), + } + + _, err := TryDetermineDefaultPRHead( + repoResolvedFromPushRemoteClient, + stubRemoteToRepoResolver(ghContext.Remotes{&baseRemote, &forkRemote}, nil), + "feature-branch", + ) + require.Error(t, err) + }) + }) + +} + +func dummyRemotesFn() (ghContext.Remotes, error) { + panic("remotes fn not implemented") +} + +func dummyRemoteToRepoResolver() remoteToRepoResolver { + return NewRemoteToRepoResolver(dummyRemotesFn) +} + +func stubRemoteToRepoResolver(remotes ghContext.Remotes, err error) remoteToRepoResolver { + return NewRemoteToRepoResolver(func() (ghContext.Remotes, error) { + return remotes, err + }) +} + +func mustParseQualifiedHeadRef(ref string) QualifiedHeadRef { + parsed, err := ParseQualifiedHeadRef(ref) + if err != nil { + panic(err) + } + return parsed +} diff --git a/pkg/cmd/pr/shared/finder.go b/pkg/cmd/pr/shared/finder.go index a5452852788..6d36ef816e4 100644 --- a/pkg/cmd/pr/shared/finder.go +++ b/pkg/cmd/pr/shared/finder.go @@ -13,11 +13,12 @@ import ( "time" "github.com/cli/cli/v2/api" - remotes "github.com/cli/cli/v2/context" + ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" fd "github.com/cli/cli/v2/internal/featuredetection" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/cmdutil" + o "github.com/cli/cli/v2/pkg/option" "github.com/cli/cli/v2/pkg/set" "github.com/shurcooL/githubv4" "golang.org/x/sync/errgroup" @@ -32,16 +33,20 @@ type progressIndicator interface { StopProgressIndicator() } +type GitConfigClient interface { + ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) + PushDefault(ctx context.Context) (git.PushDefault, error) + RemotePushDefault(ctx context.Context) (string, error) + PushRevision(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) +} + type finder struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - remotesFn func() (remotes.Remotes, error) - httpClient func() (*http.Client, error) - pushDefault func() (string, error) - remotePushDefault func() (string, error) - parsePushRevision func(string) (string, error) - branchConfig func(string) (git.BranchConfig, error) - progress progressIndicator + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + httpClient func() (*http.Client, error) + remotesFn func() (ghContext.Remotes, error) + gitConfigClient GitConfigClient + progress progressIndicator baseRefRepo ghrepo.Interface prNumber int @@ -56,23 +61,12 @@ func NewFinder(factory *cmdutil.Factory) PRFinder { } return &finder{ - baseRepoFn: factory.BaseRepo, - branchFn: factory.Branch, - remotesFn: factory.Remotes, - httpClient: factory.HttpClient, - pushDefault: func() (string, error) { - return factory.GitClient.PushDefault(context.Background()) - }, - remotePushDefault: func() (string, error) { - return factory.GitClient.RemotePushDefault(context.Background()) - }, - parsePushRevision: func(branch string) (string, error) { - return factory.GitClient.ParsePushRevision(context.Background(), branch) - }, - progress: factory.IOStreams, - branchConfig: func(s string) (git.BranchConfig, error) { - return factory.GitClient.ReadBranchConfig(context.Background(), s) - }, + baseRepoFn: factory.BaseRepo, + branchFn: factory.Branch, + httpClient: factory.HttpClient, + gitConfigClient: factory.GitClient, + remotesFn: factory.Remotes, + progress: factory.IOStreams, } } @@ -97,28 +91,6 @@ type FindOptions struct { States []string } -// TODO: Does this also need the BaseBranchName? -// PR's are represented by the following: -// baseRef -----PR-----> headRef -// -// A ref is described as "remoteName/branchName", so -// baseRepoName/baseBranchName -----PR-----> headRepoName/headBranchName -type PullRequestRefs struct { - BranchName string - HeadRepo ghrepo.Interface - BaseRepo ghrepo.Interface -} - -// GetPRHeadLabel returns the string that the GitHub API uses to identify the PR. This is -// either just the branch name or, if the PR is originating from a fork, the fork owner -// and the branch name, like :. -func (s *PullRequestRefs) GetPRHeadLabel() string { - if ghrepo.IsSame(s.HeadRepo, s.BaseRepo) { - return s.BranchName - } - return fmt.Sprintf("%s:%s", s.HeadRepo.RepoOwner(), s.BranchName) -} - func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { // If we have a URL, we don't need git stuff if len(opts.Fields) == 0 { @@ -138,7 +110,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err f.baseRefRepo = repo } - var prRefs PullRequestRefs + var prRefs PRFindRefs if opts.Selector == "" { // You must be in a git repo for this case to work currentBranchName, err := f.branchFn() @@ -148,7 +120,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err f.branchName = currentBranchName // Get the branch config for the current branchName - branchConfig, err := f.branchConfig(f.branchName) + branchConfig, err := f.gitConfigClient.ReadBranchConfig(context.Background(), f.branchName) if err != nil { return nil, nil, err } @@ -162,30 +134,19 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err // Determine the PullRequestRefs from config if f.prNumber == 0 { - rems, err := f.remotesFn() - if err != nil { - return nil, nil, err - } - - // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. - parsedPushRevision, _ := f.parsePushRevision(f.branchName) - - pushDefault, err := f.pushDefault() - if err != nil { - return nil, nil, err - } - - remotePushDefault, err := f.remotePushDefault() - if err != nil { - return nil, nil, err - } - - prRefs, err = ParsePRRefs(f.branchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, f.baseRefRepo, rems) + prRefsResolver := NewPullRequestFindRefsResolver( + // We requested the branch config already, so let's cache that + CachedBranchConfigGitConfigClient{ + CachedBranchConfig: branchConfig, + GitConfigClient: f.gitConfigClient, + }, + f.remotesFn, + ) + prRefs, err = prRefsResolver.ResolvePullRequestRefs(f.baseRefRepo, opts.BaseBranch, f.branchName) if err != nil { return nil, nil, err } } - } else if f.prNumber == 0 { // You gave me a selector but I couldn't find a PR number (it wasn't a URL) @@ -200,11 +161,17 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err f.prNumber = prNumber } else { f.branchName = opts.Selector - // We don't expect an error here because parsedPushRevision is empty - prRefs, err = ParsePRRefs(f.branchName, git.BranchConfig{}, "", "", "", f.baseRefRepo, remotes.Remotes{}) + + qualifiedHeadRef, err := ParseQualifiedHeadRef(f.branchName) if err != nil { return nil, nil, err } + + prRefs = PRFindRefs{ + qualifiedHeadRef: qualifiedHeadRef, + baseRepo: f.baseRefRepo, + baseBranchName: o.SomeIfNonZero(opts.BaseBranch), + } } } @@ -255,7 +222,7 @@ func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, err return pr, f.baseRefRepo, err } } else { - pr, err = findForBranch(httpClient, f.baseRefRepo, opts.BaseBranch, prRefs.GetPRHeadLabel(), opts.States, fields.ToSlice()) + pr, err = findForRefs(httpClient, prRefs, opts.States, fields.ToSlice()) if err != nil { return pr, f.baseRefRepo, err } @@ -317,72 +284,6 @@ func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { return repo, prNumber, nil } -func ParsePRRefs(currentBranchName string, branchConfig git.BranchConfig, parsedPushRevision string, pushDefault string, remotePushDefault string, baseRefRepo ghrepo.Interface, rems remotes.Remotes) (PullRequestRefs, error) { - prRefs := PullRequestRefs{ - BaseRepo: baseRefRepo, - } - - // If @{push} resolves, then we have all the information we need to determine the head repo - // and branch name. It is of the form /. - if parsedPushRevision != "" { - for _, r := range rems { - // Find the remote who's name matches the push prefix - if strings.HasPrefix(parsedPushRevision, r.Name+"/") { - prRefs.BranchName = strings.TrimPrefix(parsedPushRevision, r.Name+"/") - prRefs.HeadRepo = r.Repo - return prRefs, nil - } - } - - remoteNames := make([]string, len(rems)) - for i, r := range rems { - remoteNames[i] = r.Name - } - return PullRequestRefs{}, fmt.Errorf("no remote for %q found in %q", parsedPushRevision, strings.Join(remoteNames, ", ")) - } - - // We assume the PR's branch name is the same as whatever f.BranchFn() returned earlier - // unless the user has specified push.default = upstream or tracking, then we use the - // branch name from the merge ref. - prRefs.BranchName = currentBranchName - if pushDefault == "upstream" || pushDefault == "tracking" { - prRefs.BranchName = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") - } - - // To get the HeadRepo, we look to the git config. The HeadRepo comes from one of the following, in order of precedence: - // 1. branch..pushRemote - // 2. remote.pushDefault - // 3. branch..remote - if branchConfig.PushRemoteName != "" { - if r, err := rems.FindByName(branchConfig.PushRemoteName); err == nil { - prRefs.HeadRepo = r.Repo - } - } else if branchConfig.PushRemoteURL != nil { - if r, err := ghrepo.FromURL(branchConfig.PushRemoteURL); err == nil { - prRefs.HeadRepo = r - } - } else if remotePushDefault != "" { - if r, err := rems.FindByName(remotePushDefault); err == nil { - prRefs.HeadRepo = r.Repo - } - } else if branchConfig.RemoteName != "" { - if r, err := rems.FindByName(branchConfig.RemoteName); err == nil { - prRefs.HeadRepo = r.Repo - } - } else if branchConfig.RemoteURL != nil { - if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { - prRefs.HeadRepo = r - } - } - - // The PR merges from a branch in the same repo as the base branch (usually the default branch) - if prRefs.HeadRepo == nil { - prRefs.HeadRepo = baseRefRepo - } - - return prRefs, nil -} - func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { type response struct { Repository struct { @@ -413,7 +314,7 @@ func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fi return &resp.Repository.PullRequest, nil } -func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranchWithOwnerIfFork string, stateFilters, fields []string) (*api.PullRequest, error) { +func findForRefs(httpClient *http.Client, prRefs PRFindRefs, stateFilters, fields []string) (*api.PullRequest, error) { type response struct { Repository struct { PullRequests struct { @@ -440,21 +341,16 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h } }`, api.PullRequestGraphQL(fieldSet.ToSlice())) - branchWithoutOwner := headBranchWithOwnerIfFork - if idx := strings.Index(headBranchWithOwnerIfFork, ":"); idx >= 0 { - branchWithoutOwner = headBranchWithOwnerIfFork[idx+1:] - } - variables := map[string]interface{}{ - "owner": repo.RepoOwner(), - "repo": repo.RepoName(), - "headRefName": branchWithoutOwner, + "owner": prRefs.BaseRepo().RepoOwner(), + "repo": prRefs.BaseRepo().RepoName(), + "headRefName": prRefs.UnqualifiedHeadRef(), "states": stateFilters, } var resp response client := api.NewClientFromHTTP(httpClient) - err := client.GraphQL(repo.RepoHost(), query, variables, &resp) + err := client.GraphQL(prRefs.BaseRepo().RepoHost(), query, variables, &resp) if err != nil { return nil, err } @@ -465,17 +361,15 @@ func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, h }) for _, pr := range prs { - headBranchMatches := pr.HeadLabel() == headBranchWithOwnerIfFork - baseBranchEmptyOrMatches := baseBranch == "" || pr.BaseRefName == baseBranch // When the head is the default branch, it doesn't really make sense to show merged or closed PRs. // https://github.com/cli/cli/issues/4263 - isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranchWithOwnerIfFork - if headBranchMatches && baseBranchEmptyOrMatches && isNotClosedOrMergedWhenHeadIsDefault { + isNotClosedOrMergedWhenHeadIsDefault := pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != prRefs.QualifiedHeadRef() + if prRefs.Matches(pr.BaseRefName, pr.HeadLabel()) && isNotClosedOrMergedWhenHeadIsDefault { return &pr, nil } } - return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranchWithOwnerIfFork)} + return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", prRefs.QualifiedHeadRef())} } func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { diff --git a/pkg/cmd/pr/shared/finder_test.go b/pkg/cmd/pr/shared/finder_test.go index 36551ab4294..e1aae16b114 100644 --- a/pkg/cmd/pr/shared/finder_test.go +++ b/pkg/cmd/pr/shared/finder_test.go @@ -1,46 +1,41 @@ package shared import ( + "context" "errors" - "fmt" "net/http" "net/url" "testing" - "github.com/cli/cli/v2/context" + ghContext "github.com/cli/cli/v2/context" "github.com/cli/cli/v2/git" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type args struct { - baseRepoFn func() (ghrepo.Interface, error) - branchFn func() (string, error) - branchConfig func(string) (git.BranchConfig, error) - pushDefault func() (string, error) - remotePushDefault func() (string, error) - parsePushRevision func(string) (string, error) - selector string - fields []string - baseBranch string + baseRepoFn func() (ghrepo.Interface, error) + branchFn func() (string, error) + gitConfigClient stubGitConfigClient + selector string + fields []string + baseBranch string } func TestFind(t *testing.T) { - // TODO: Abstract these out meaningfully for reuse in parsePRRefs tests originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git") if err != nil { t.Fatal(err) } - remoteOrigin := context.Remote{ + remoteOrigin := ghContext.Remote{ Remote: &git.Remote{ Name: "origin", FetchURL: originOwnerUrl, }, Repo: ghrepo.New("ORIGINOWNER", "REPO"), } - remoteOther := context.Remote{ + remoteOther := ghContext.Remote{ Remote: &git.Remote{ Name: "other", FetchURL: originOwnerUrl, @@ -52,7 +47,7 @@ func TestFind(t *testing.T) { if err != nil { t.Fatal(err) } - remoteUpstream := context.Remote{ + remoteUpstream := ghContext.Remote{ Remote: &git.Remote{ Name: "upstream", FetchURL: upstreamOwnerUrl, @@ -77,7 +72,6 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -99,12 +93,14 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - PushRemoteName: remoteOrigin.Remote.Name, - }, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + PushRemoteName: remoteOrigin.Remote.Name, + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -134,9 +130,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, wantErr: true, }, @@ -157,9 +155,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: nil, wantPR: 13, @@ -174,9 +174,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -197,9 +199,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -223,15 +227,17 @@ func TestFind(t *testing.T) { ExitCode: 128, } }, - branchConfig: stubBranchConfig(git.BranchConfig{}, &git.GitError{ - Stderr: "fatal: branchConfig error", - ExitCode: 128, - }), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", &git.GitError{ - Stderr: "fatal: remotePushDefault error", - ExitCode: 128, - }), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, &git.GitError{ + Stderr: "fatal: branchConfig error", + ExitCode: 128, + }), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", &git.GitError{ + Stderr: "fatal: remotePushDefault error", + ExitCode: 128, + }), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -252,10 +258,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - parsePushRevision: stubParsedPushRevision("", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -296,10 +304,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -339,10 +349,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -374,10 +386,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -423,13 +437,15 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/heads/blue-upstream-berries", - PushRemoteName: "upstream", - }, nil), - pushDefault: stubPushDefault("upstream", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/heads/blue-upstream-berries", + PushRemoteName: "upstream", + }, nil), + pushDefaultFn: stubPushDefault("upstream", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -463,13 +479,15 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/heads/blue-upstream-berries", - PushRemoteURL: remoteUpstream.Remote.FetchURL, - }, nil), - pushDefault: stubPushDefault("upstream", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/heads/blue-upstream-berries", + PushRemoteURL: remoteUpstream.Remote.FetchURL, + }, nil), + pushDefaultFn: stubPushDefault("upstream", nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{}, errors.New("testErr")), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -499,10 +517,12 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{}, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), - parsePushRevision: stubParsedPushRevision("other/blueberries", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{}, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + pushRevisionFn: stubPushRevision(git.RemoteTrackingRef{Remote: "other", Branch: "blueberries"}, nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -534,9 +554,11 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/pull/13/head", - }, nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/pull/13/head", + }, nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -559,11 +581,13 @@ func TestFind(t *testing.T) { branchFn: func() (string, error) { return "blueberries", nil }, - branchConfig: stubBranchConfig(git.BranchConfig{ - MergeRef: "refs/pull/13/head", - }, nil), - pushDefault: stubPushDefault("simple", nil), - remotePushDefault: stubRemotePushDefault("", nil), + gitConfigClient: stubGitConfigClient{ + readBranchConfigFn: stubBranchConfig(git.BranchConfig{ + MergeRef: "refs/pull/13/head", + }, nil), + pushDefaultFn: stubPushDefault(git.PushDefaultSimple, nil), + remotePushDefaultFn: stubRemotePushDefault("", nil), + }, }, httpStub: func(r *httpmock.Registry) { r.Register( @@ -575,32 +599,32 @@ func TestFind(t *testing.T) { r.Register( httpmock.GraphQL(`query PullRequestProjectItems\b`), httpmock.GraphQLQuery(`{ - "data": { - "repository": { - "pullRequest": { - "projectItems": { - "nodes": [ - { - "id": "PVTI_lADOB-vozM4AVk16zgK6U50", - "project": { - "id": "PVT_kwDOB-vozM4AVk16", - "title": "Test Project" - }, - "status": { - "optionId": "47fc9ee4", - "name": "In Progress" - } - } - ], - "pageInfo": { - "hasNextPage": false, - "endCursor": "MQ" - } - } - } - } - } - }`, + "data": { + "repository": { + "pullRequest": { + "projectItems": { + "nodes": [ + { + "id": "PVTI_lADOB-vozM4AVk16zgK6U50", + "project": { + "id": "PVT_kwDOB-vozM4AVk16", + "title": "Test Project" + }, + "status": { + "optionId": "47fc9ee4", + "name": "In Progress" + } + } + ], + "pageInfo": { + "hasNextPage": false, + "endCursor": "MQ" + } + } + } + } + } + }`, func(query string, inputs map[string]interface{}) { require.Equal(t, float64(13), inputs["number"]) require.Equal(t, "OWNER", inputs["owner"]) @@ -624,13 +648,10 @@ func TestFind(t *testing.T) { httpClient: func() (*http.Client, error) { return &http.Client{Transport: reg}, nil }, - baseRepoFn: tt.args.baseRepoFn, - branchFn: tt.args.branchFn, - branchConfig: tt.args.branchConfig, - pushDefault: tt.args.pushDefault, - remotePushDefault: tt.args.remotePushDefault, - parsePushRevision: tt.args.parsePushRevision, - remotesFn: stubRemotes(context.Remotes{ + baseRepoFn: tt.args.baseRepoFn, + branchFn: tt.args.branchFn, + gitConfigClient: tt.args.gitConfigClient, + remotesFn: stubRemotes(ghContext.Remotes{ &remoteOrigin, &remoteOther, &remoteUpstream, @@ -667,343 +688,73 @@ func TestFind(t *testing.T) { } } -func TestParsePRRefs(t *testing.T) { - originOwnerUrl, err := url.Parse("https://github.com/ORIGINOWNER/REPO.git") - if err != nil { - t.Fatal(err) - } - remoteOrigin := context.Remote{ - Remote: &git.Remote{ - Name: "origin", - FetchURL: originOwnerUrl, - }, - Repo: ghrepo.New("ORIGINOWNER", "REPO"), - } - remoteOther := context.Remote{ - Remote: &git.Remote{ - Name: "other", - FetchURL: originOwnerUrl, - }, - Repo: ghrepo.New("ORIGINOWNER", "REPO"), +func stubBranchConfig(branchConfig git.BranchConfig, err error) func(context.Context, string) (git.BranchConfig, error) { + return func(_ context.Context, branch string) (git.BranchConfig, error) { + return branchConfig, err } +} - upstreamOwnerUrl, err := url.Parse("https://github.com/UPSTREAMOWNER/REPO.git") - if err != nil { - t.Fatal(err) - } - remoteUpstream := context.Remote{ - Remote: &git.Remote{ - Name: "upstream", - FetchURL: upstreamOwnerUrl, - }, - Repo: ghrepo.New("UPSTREAMOWNER", "REPO"), +func stubRemotes(remotes ghContext.Remotes, err error) func() (ghContext.Remotes, error) { + return func() (ghContext.Remotes, error) { + return remotes, err } +} - tests := []struct { - name string - branchConfig git.BranchConfig - pushDefault string - parsedPushRevision string - remotePushDefault string - currentBranchName string - baseRefRepo ghrepo.Interface - rems context.Remotes - wantPRRefs PullRequestRefs - wantErr error - }{ - { - name: "When the branch is called 'blueberries' with an empty branch config, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{}, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the branch is called 'otherBranch' with an empty branch config, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{}, - currentBranchName: "otherBranch", - baseRefRepo: remoteOrigin.Repo, - wantPRRefs: PullRequestRefs{ - BranchName: "otherBranch", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the branch name doesn't match the branch name in BranchConfig.Push, it returns the BranchConfig.Push branch name", - parsedPushRevision: "origin/pushBranch", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "pushBranch", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push revision doesn't match a remote, it returns an error", - parsedPushRevision: "origin/differentPushBranch", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteUpstream, - &remoteOther, - }, - wantPRRefs: PullRequestRefs{}, - wantErr: fmt.Errorf("no remote for %q found in %q", "origin/differentPushBranch", "upstream, other"), - }, - { - name: "When the branch name doesn't match a different branch name in BranchConfig.Push and the remote isn't 'origin', it returns the BranchConfig.Push branch name", - parsedPushRevision: "other/pushBranch", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOther, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "pushBranch", - HeadRepo: remoteOther.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote is the same as the baseRepo, it returns the baseRepo as the PullRequestRefs HeadRepo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteOrigin.Remote.Name, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote is different from the baseRepo, it returns the push remote repo as the PullRequestRefs HeadRepo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteOrigin.Remote.Name, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteUpstream.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteUpstream.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote defined by a URL and the baseRepo is different from the push remote, it returns the push remote repo as the PullRequestRefs HeadRepo", - branchConfig: git.BranchConfig{ - PushRemoteURL: remoteOrigin.Remote.FetchURL, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteUpstream.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteOrigin.Repo, - BaseRepo: remoteUpstream.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote and merge ref are configured to a different repo and push.default = upstream, it should return the branch name from the other repo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteUpstream.Remote.Name, - MergeRef: "refs/heads/blue-upstream-berries", - }, - pushDefault: "upstream", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blue-upstream-berries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the push remote and merge ref are configured to a different repo and push.default = tracking, it should return the branch name from the other repo", - branchConfig: git.BranchConfig{ - PushRemoteName: remoteUpstream.Remote.Name, - MergeRef: "refs/heads/blue-upstream-berries", - }, - pushDefault: "tracking", - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blue-upstream-berries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When remote.pushDefault is set, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{}, - remotePushDefault: remoteUpstream.Remote.Name, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the remote name is set on the branch, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{ - RemoteName: remoteUpstream.Remote.Name, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - { - name: "When the remote URL is set on the branch, it returns the correct PullRequestRefs", - branchConfig: git.BranchConfig{ - RemoteURL: remoteUpstream.Remote.FetchURL, - }, - currentBranchName: "blueberries", - baseRefRepo: remoteOrigin.Repo, - rems: context.Remotes{ - &remoteOrigin, - &remoteUpstream, - }, - wantPRRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: remoteUpstream.Repo, - BaseRepo: remoteOrigin.Repo, - }, - wantErr: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - prRefs, err := ParsePRRefs(tt.currentBranchName, tt.branchConfig, tt.parsedPushRevision, tt.pushDefault, tt.remotePushDefault, tt.baseRefRepo, tt.rems) - if tt.wantErr != nil { - require.Equal(t, tt.wantErr, err) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantPRRefs, prRefs) - }) +func stubBaseRepoFn(baseRepo ghrepo.Interface, err error) func() (ghrepo.Interface, error) { + return func() (ghrepo.Interface, error) { + return baseRepo, err } } -func TestPRRefs_GetPRHeadLabel(t *testing.T) { - originRepo := ghrepo.New("ORIGINOWNER", "REPO") - upstreamRepo := ghrepo.New("UPSTREAMOWNER", "REPO") - tests := []struct { - name string - prRefs PullRequestRefs - want string - }{ - { - name: "When the HeadRepo and BaseRepo match, it returns the branch name", - prRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: originRepo, - BaseRepo: originRepo, - }, - want: "blueberries", - }, - { - name: "When the HeadRepo and BaseRepo do not match, it returns the prepended HeadRepo owner to the branch name", - prRefs: PullRequestRefs{ - BranchName: "blueberries", - HeadRepo: originRepo, - BaseRepo: upstreamRepo, - }, - want: "ORIGINOWNER:blueberries", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, tt.prRefs.GetPRHeadLabel()) - }) +func stubPushDefault(pushDefault git.PushDefault, err error) func(context.Context) (git.PushDefault, error) { + return func(_ context.Context) (git.PushDefault, error) { + return pushDefault, err } } -func stubBranchConfig(branchConfig git.BranchConfig, err error) func(string) (git.BranchConfig, error) { - return func(branch string) (git.BranchConfig, error) { - return branchConfig, err +func stubRemotePushDefault(remotePushDefault string, err error) func(context.Context) (string, error) { + return func(_ context.Context) (string, error) { + return remotePushDefault, err } } -func stubRemotes(remotes context.Remotes, err error) func() (context.Remotes, error) { - return func() (context.Remotes, error) { - return remotes, err +func stubPushRevision(parsedPushRevision git.RemoteTrackingRef, err error) func(context.Context, string) (git.RemoteTrackingRef, error) { + return func(_ context.Context, _ string) (git.RemoteTrackingRef, error) { + return parsedPushRevision, err } } -func stubBaseRepoFn(baseRepo ghrepo.Interface, err error) func() (ghrepo.Interface, error) { - return func() (ghrepo.Interface, error) { - return baseRepo, err +type stubGitConfigClient struct { + readBranchConfigFn func(ctx context.Context, branchName string) (git.BranchConfig, error) + pushDefaultFn func(ctx context.Context) (git.PushDefault, error) + remotePushDefaultFn func(ctx context.Context) (string, error) + pushRevisionFn func(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) +} + +func (s stubGitConfigClient) ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) { + if s.readBranchConfigFn == nil { + panic("unexpected call to ReadBranchConfig") } + return s.readBranchConfigFn(ctx, branchName) } -func stubPushDefault(pushDefault string, err error) func() (string, error) { - return func() (string, error) { - return pushDefault, err +func (s stubGitConfigClient) PushDefault(ctx context.Context) (git.PushDefault, error) { + if s.pushDefaultFn == nil { + panic("unexpected call to PushDefault") } + return s.pushDefaultFn(ctx) } -func stubRemotePushDefault(remotePushDefault string, err error) func() (string, error) { - return func() (string, error) { - return remotePushDefault, err +func (s stubGitConfigClient) RemotePushDefault(ctx context.Context) (string, error) { + if s.remotePushDefaultFn == nil { + panic("unexpected call to RemotePushDefault") } + return s.remotePushDefaultFn(ctx) } -func stubParsedPushRevision(parsedPushRevision string, err error) func(string) (string, error) { - return func(_ string) (string, error) { - return parsedPushRevision, err +func (s stubGitConfigClient) PushRevision(ctx context.Context, branchName string) (git.RemoteTrackingRef, error) { + if s.pushRevisionFn == nil { + panic("unexpected call to PushRevision") } + return s.pushRevisionFn(ctx, branchName) } diff --git a/pkg/cmd/pr/shared/git_cached_config_client.go b/pkg/cmd/pr/shared/git_cached_config_client.go new file mode 100644 index 00000000000..aea25abeea0 --- /dev/null +++ b/pkg/cmd/pr/shared/git_cached_config_client.go @@ -0,0 +1,18 @@ +package shared + +import ( + "context" + + "github.com/cli/cli/v2/git" +) + +var _ GitConfigClient = &CachedBranchConfigGitConfigClient{} + +type CachedBranchConfigGitConfigClient struct { + CachedBranchConfig git.BranchConfig + GitConfigClient +} + +func (c CachedBranchConfigGitConfigClient) ReadBranchConfig(ctx context.Context, branchName string) (git.BranchConfig, error) { + return c.CachedBranchConfig, nil +} diff --git a/pkg/cmd/pr/shared/params.go b/pkg/cmd/pr/shared/params.go index 128c51068a0..4f36a80aaa5 100644 --- a/pkg/cmd/pr/shared/params.go +++ b/pkg/cmd/pr/shared/params.go @@ -6,12 +6,13 @@ import ( "strings" "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/search" "github.com/google/shlex" ) -func WithPrAndIssueQueryParams(client *api.Client, baseRepo ghrepo.Interface, baseURL string, state IssueMetadataState) (string, error) { +func WithPrAndIssueQueryParams(client *api.Client, baseRepo ghrepo.Interface, baseURL string, state IssueMetadataState, projectsV1Support gh.ProjectsV1Support) (string, error) { u, err := url.Parse(baseURL) if err != nil { return "", err @@ -34,8 +35,8 @@ func WithPrAndIssueQueryParams(client *api.Client, baseRepo ghrepo.Interface, ba if len(state.Labels) > 0 { q.Set("labels", strings.Join(state.Labels, ",")) } - if len(state.Projects) > 0 { - projectPaths, err := api.ProjectNamesToPaths(client, baseRepo, state.Projects) + if len(state.ProjectTitles) > 0 { + projectPaths, err := api.ProjectNamesToPaths(client, baseRepo, state.ProjectTitles, projectsV1Support) if err != nil { return "", fmt.Errorf("could not add to project: %w", err) } @@ -56,7 +57,7 @@ func ValidURL(urlStr string) bool { // Ensure that tb.MetadataResult object exists and contains enough pre-fetched API data to be able // to resolve all object listed in tb to GraphQL IDs. -func fillMetadata(client *api.Client, baseRepo ghrepo.Interface, tb *IssueMetadataState) error { +func fillMetadata(client *api.Client, baseRepo ghrepo.Interface, tb *IssueMetadataState, projectV1Support gh.ProjectsV1Support) error { resolveInput := api.RepoResolveInput{} if len(tb.Assignees) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.AssignableUsers) == 0) { @@ -71,8 +72,12 @@ func fillMetadata(client *api.Client, baseRepo ghrepo.Interface, tb *IssueMetada resolveInput.Labels = tb.Labels } - if len(tb.Projects) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.Projects) == 0) { - resolveInput.Projects = tb.Projects + if len(tb.ProjectTitles) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.Projects) == 0) { + if projectV1Support == gh.ProjectsV1Supported { + resolveInput.ProjectsV1 = true + } + + resolveInput.ProjectsV2 = true } if len(tb.Milestones) > 0 && (tb.MetadataResult == nil || len(tb.MetadataResult.Milestones) == 0) { @@ -93,12 +98,12 @@ func fillMetadata(client *api.Client, baseRepo ghrepo.Interface, tb *IssueMetada return nil } -func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, params map[string]interface{}, tb *IssueMetadataState) error { +func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, params map[string]interface{}, tb *IssueMetadataState, projectV1Support gh.ProjectsV1Support) error { if !tb.HasMetadata() { return nil } - if err := fillMetadata(client, baseRepo, tb); err != nil { + if err := fillMetadata(client, baseRepo, tb, projectV1Support); err != nil { return err } @@ -114,7 +119,7 @@ func AddMetadataToIssueParams(client *api.Client, baseRepo ghrepo.Interface, par } params["labelIds"] = labelIDs - projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsToIDs(tb.Projects) + projectIDs, projectV2IDs, err := tb.MetadataResult.ProjectsToIDs(tb.ProjectTitles) if err != nil { return fmt.Errorf("could not add to project: %w", err) } diff --git a/pkg/cmd/pr/shared/params_test.go b/pkg/cmd/pr/shared/params_test.go index 5f5e674cc0f..15f00ca4f22 100644 --- a/pkg/cmd/pr/shared/params_test.go +++ b/pkg/cmd/pr/shared/params_test.go @@ -2,13 +2,16 @@ package shared import ( "net/http" + "net/url" "reflect" "testing" "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/pkg/httpmock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_listURLWithQuery(t *testing.T) { @@ -265,7 +268,7 @@ func Test_WithPrAndIssueQueryParams(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := WithPrAndIssueQueryParams(nil, nil, tt.args.baseURL, tt.args.state) + got, err := WithPrAndIssueQueryParams(nil, nil, tt.args.baseURL, tt.args.state, gh.ProjectsV1Supported) if (err != nil) != tt.wantErr { t.Errorf("WithPrAndIssueQueryParams() error = %v, wantErr %v", err, tt.wantErr) return @@ -276,3 +279,144 @@ func Test_WithPrAndIssueQueryParams(t *testing.T) { }) } } + +// TODO projectsV1Deprecation +// Remove this test. +func TestWithPrAndIssueQueryParamsProjectsV1Deprecation(t *testing.T) { + t.Run("when projectsV1 is supported, requests them", func(t *testing.T) { + reg := &httpmock.Registry{} + client := api.NewClientFromHTTP(&http.Client{ + Transport: reg, + }) + + repo, _ := ghrepo.FromFullName("OWNER/REPO") + + reg.Register( + httpmock.GraphQL(`query RepositoryProjectList\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projects": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectList\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projects": { + "nodes": [ + { "name": "Triage", "id": "TRIAGEID", "resourcePath": "/orgs/ORG/projects/1" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + + u, err := WithPrAndIssueQueryParams( + client, + repo, + "http://example.com/hey", + IssueMetadataState{ + ProjectTitles: []string{"Triage"}, + }, + gh.ProjectsV1Supported, + ) + require.NoError(t, err) + + url, err := url.Parse(u) + require.NoError(t, err) + + require.Equal( + t, + url.Query().Get("projects"), + "ORG/1", + ) + }) + + t.Run("when projectsV1 is not supported, does not request them", func(t *testing.T) { + reg := &httpmock.Registry{} + client := api.NewClientFromHTTP(&http.Client{ + Transport: reg, + }) + + repo, _ := ghrepo.FromFullName("OWNER/REPO") + + reg.Exclude( + t, + httpmock.GraphQL(`query RepositoryProjectList\b`), + ) + reg.Exclude( + t, + httpmock.GraphQL(`query OrganizationProjectList\b`), + ) + + reg.Register( + httpmock.GraphQL(`query RepositoryProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "repository": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query OrganizationProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "organization": { "projectsV2": { + "nodes": [ + { "title": "TriageV2", "id": "TRIAGEV2ID", "resourcePath": "/orgs/ORG/projects/2" } + ], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + reg.Register( + httpmock.GraphQL(`query UserProjectV2List\b`), + httpmock.StringResponse(` + { "data": { "viewer": { "projectsV2": { + "nodes": [], + "pageInfo": { "hasNextPage": false } + } } } } + `)) + + u, err := WithPrAndIssueQueryParams( + client, + repo, + "http://example.com/hey", + IssueMetadataState{ + ProjectTitles: []string{"TriageV2"}, + }, + gh.ProjectsV1Unsupported, + ) + require.NoError(t, err) + + url, err := url.Parse(u) + require.NoError(t, err) + + require.Equal( + t, + url.Query().Get("projects"), + "ORG/2", + ) + }) +} diff --git a/pkg/cmd/pr/shared/state.go b/pkg/cmd/pr/shared/state.go index 143021cb60b..7e7da436d03 100644 --- a/pkg/cmd/pr/shared/state.go +++ b/pkg/cmd/pr/shared/state.go @@ -25,12 +25,12 @@ type IssueMetadataState struct { Template string - Metadata []string - Reviewers []string - Assignees []string - Labels []string - Projects []string - Milestones []string + Metadata []string + Reviewers []string + Assignees []string + Labels []string + ProjectTitles []string + Milestones []string MetadataResult *api.RepoMetadataResult @@ -49,7 +49,7 @@ func (tb *IssueMetadataState) HasMetadata() bool { return len(tb.Reviewers) > 0 || len(tb.Assignees) > 0 || len(tb.Labels) > 0 || - len(tb.Projects) > 0 || + len(tb.ProjectTitles) > 0 || len(tb.Milestones) > 0 } diff --git a/pkg/cmd/pr/shared/survey.go b/pkg/cmd/pr/shared/survey.go index ce38535d97b..bf4476ca1ed 100644 --- a/pkg/cmd/pr/shared/survey.go +++ b/pkg/cmd/pr/shared/survey.go @@ -151,7 +151,7 @@ type RepoMetadataFetcher interface { RepoMetadataFetch(api.RepoMetadataInput) (*api.RepoMetadataResult, error) } -func MetadataSurvey(p Prompt, io *iostreams.IOStreams, baseRepo ghrepo.Interface, fetcher RepoMetadataFetcher, state *IssueMetadataState) error { +func MetadataSurvey(p Prompt, io *iostreams.IOStreams, baseRepo ghrepo.Interface, fetcher RepoMetadataFetcher, state *IssueMetadataState, projectsV1Support gh.ProjectsV1Support) error { isChosen := func(m string) bool { for _, c := range state.Metadata { if m == c { @@ -181,7 +181,8 @@ func MetadataSurvey(p Prompt, io *iostreams.IOStreams, baseRepo ghrepo.Interface Reviewers: isChosen("Reviewers"), Assignees: isChosen("Assignees"), Labels: isChosen("Labels"), - Projects: isChosen("Projects"), + ProjectsV1: isChosen("Projects") && projectsV1Support == gh.ProjectsV1Supported, + ProjectsV2: isChosen("Projects"), Milestones: isChosen("Milestone"), } metadataResult, err := fetcher.RepoMetadataFetch(metadataInput) @@ -267,7 +268,7 @@ func MetadataSurvey(p Prompt, io *iostreams.IOStreams, baseRepo ghrepo.Interface } if isChosen("Projects") { if len(projects) > 0 { - selected, err := p.MultiSelect("Projects", state.Projects, projects) + selected, err := p.MultiSelect("Projects", state.ProjectTitles, projects) if err != nil { return err } @@ -316,7 +317,7 @@ func MetadataSurvey(p Prompt, io *iostreams.IOStreams, baseRepo ghrepo.Interface state.Labels = values.Labels } if isChosen("Projects") { - state.Projects = values.Projects + state.ProjectTitles = values.Projects } if isChosen("Milestone") { if values.Milestone != "" && values.Milestone != noMilestone { diff --git a/pkg/cmd/pr/shared/survey_test.go b/pkg/cmd/pr/shared/survey_test.go index d74696460a2..6895b52ac99 100644 --- a/pkg/cmd/pr/shared/survey_test.go +++ b/pkg/cmd/pr/shared/survey_test.go @@ -1,13 +1,16 @@ package shared import ( + "errors" "testing" "github.com/cli/cli/v2/api" + "github.com/cli/cli/v2/internal/gh" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/iostreams" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type metadataFetcher struct { @@ -68,7 +71,7 @@ func TestMetadataSurvey_selectAll(t *testing.T) { Assignees: []string{"hubot"}, Type: PRMetadata, } - err := MetadataSurvey(pm, ios, repo, fetcher, state) + err := MetadataSurvey(pm, ios, repo, fetcher, state, gh.ProjectsV1Supported) assert.NoError(t, err) assert.Equal(t, "", stdout.String()) @@ -77,7 +80,7 @@ func TestMetadataSurvey_selectAll(t *testing.T) { assert.Equal(t, []string{"hubot"}, state.Assignees) assert.Equal(t, []string{"monalisa"}, state.Reviewers) assert.Equal(t, []string{"good first issue"}, state.Labels) - assert.Equal(t, []string{"The road to 1.0"}, state.Projects) + assert.Equal(t, []string{"The road to 1.0"}, state.ProjectTitles) assert.Equal(t, []string{}, state.Milestones) } @@ -113,7 +116,8 @@ func TestMetadataSurvey_keepExisting(t *testing.T) { state := &IssueMetadataState{ Assignees: []string{"hubot"}, } - err := MetadataSurvey(pm, ios, repo, fetcher, state) + + err := MetadataSurvey(pm, ios, repo, fetcher, state, gh.ProjectsV1Supported) assert.NoError(t, err) assert.Equal(t, "", stdout.String()) @@ -121,7 +125,64 @@ func TestMetadataSurvey_keepExisting(t *testing.T) { assert.Equal(t, []string{"hubot"}, state.Assignees) assert.Equal(t, []string{"good first issue"}, state.Labels) - assert.Equal(t, []string{"The road to 1.0"}, state.Projects) + assert.Equal(t, []string{"The road to 1.0"}, state.ProjectTitles) +} + +// TODO projectsV1Deprecation +// Remove this test and projectsV1MetadataFetcherSpy +func TestMetadataSurveyProjectV1Deprecation(t *testing.T) { + t.Run("when projectsV1 is supported, requests projectsV1", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + repo := ghrepo.New("OWNER", "REPO") + + fetcher := &projectsV1MetadataFetcherSpy{} + pm := prompter.NewMockPrompter(t) + pm.RegisterMultiSelect("What would you like to add?", []string{}, []string{"Assignees", "Labels", "Projects", "Milestone"}, func(_ string, _, options []string) ([]int, error) { + i, err := prompter.IndexFor(options, "Projects") + require.NoError(t, err) + return []int{i}, nil + }) + pm.RegisterMultiSelect("Projects", []string{}, []string{"Huge Refactoring"}, func(_ string, _, _ []string) ([]int, error) { + return []int{0}, nil + }) + + err := MetadataSurvey(pm, ios, repo, fetcher, &IssueMetadataState{}, gh.ProjectsV1Supported) + require.ErrorContains(t, err, "expected test error") + + require.True(t, fetcher.projectsV1Requested, "expected projectsV1 to be requested") + }) + + t.Run("when projectsV1 is supported, does not request projectsV1", func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + repo := ghrepo.New("OWNER", "REPO") + + fetcher := &projectsV1MetadataFetcherSpy{} + pm := prompter.NewMockPrompter(t) + pm.RegisterMultiSelect("What would you like to add?", []string{}, []string{"Assignees", "Labels", "Projects", "Milestone"}, func(_ string, _, options []string) ([]int, error) { + i, err := prompter.IndexFor(options, "Projects") + require.NoError(t, err) + return []int{i}, nil + }) + pm.RegisterMultiSelect("Projects", []string{}, []string{"Huge Refactoring"}, func(_ string, _, _ []string) ([]int, error) { + return []int{0}, nil + }) + + err := MetadataSurvey(pm, ios, repo, fetcher, &IssueMetadataState{}, gh.ProjectsV1Unsupported) + require.ErrorContains(t, err, "expected test error") + + require.False(t, fetcher.projectsV1Requested, "expected projectsV1 not to be requested") + }) +} + +type projectsV1MetadataFetcherSpy struct { + projectsV1Requested bool +} + +func (mf *projectsV1MetadataFetcherSpy) RepoMetadataFetch(input api.RepoMetadataInput) (*api.RepoMetadataResult, error) { + if input.ProjectsV1 { + mf.projectsV1Requested = true + } + return nil, errors.New("expected test error") } func TestTitledEditSurvey_cleanupHint(t *testing.T) { diff --git a/pkg/cmd/pr/status/status.go b/pkg/cmd/pr/status/status.go index eb120e5a7df..60202594f54 100644 --- a/pkg/cmd/pr/status/status.go +++ b/pkg/cmd/pr/status/status.go @@ -102,43 +102,34 @@ func statusRun(opts *StatusOptions) error { return fmt.Errorf("could not query for pull request for current branch: %w", err) } - branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName) - if err != nil { - return err - } - // Determine if the branch is configured to merge to a special PR ref - prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) - if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { - currentPRNumber, _ = strconv.Atoi(m[1]) - } - - if currentPRNumber == 0 { - remotes, err := opts.Remotes() + if !errors.Is(err, git.ErrNotOnAnyBranch) { + branchConfig, err := opts.GitClient.ReadBranchConfig(ctx, currentBranchName) if err != nil { return err } - // Suppressing these errors as we have other means of computing the PullRequestRefs when these fail. - parsedPushRevision, _ := opts.GitClient.ParsePushRevision(ctx, currentBranchName) - - remotePushDefault, err := opts.GitClient.RemotePushDefault(ctx) - if err != nil { - return err + // Determine if the branch is configured to merge to a special PR ref + prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`) + if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { + currentPRNumber, _ = strconv.Atoi(m[1]) } - pushDefault, err := opts.GitClient.PushDefault(ctx) - if err != nil { - return err - } + if currentPRNumber == 0 { + prRefsResolver := shared.NewPullRequestFindRefsResolver( + // We requested the branch config already, so let's cache that + shared.CachedBranchConfigGitConfigClient{ + CachedBranchConfig: branchConfig, + GitConfigClient: opts.GitClient, + }, + opts.Remotes, + ) + + prRefs, err := prRefsResolver.ResolvePullRequestRefs(baseRefRepo, "", currentBranchName) + if err != nil { + return err + } - prRefs, err := shared.ParsePRRefs(currentBranchName, branchConfig, parsedPushRevision, pushDefault, remotePushDefault, baseRefRepo, remotes) - if err != nil { - return err + currentHeadRefBranchName = prRefs.QualifiedHeadRef() } - currentHeadRefBranchName = prRefs.BranchName - } - - if err != nil { - return fmt.Errorf("could not query for pull request for current branch: %w", err) } } diff --git a/pkg/cmd/pr/status/status_test.go b/pkg/cmd/pr/status/status_test.go index c55604c2843..41c01e9150f 100644 --- a/pkg/cmd/pr/status/status_test.go +++ b/pkg/cmd/pr/status/status_test.go @@ -98,10 +98,10 @@ func TestPRStatus(t *testing.T) { // stub successful git commands rs, cleanup := run.Stub() defer cleanup(t) + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -133,8 +133,8 @@ func TestPRStatus_reviewsAndChecks(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -166,8 +166,8 @@ func TestPRStatus_reviewsAndChecksWithStatesByCount(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommandWithDetector(http, "blueberries", true, "", &fd.EnabledDetectorMock{}) if err != nil { @@ -198,8 +198,8 @@ func TestPRStatus_currentBranch_showTheMostRecentPR(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -234,8 +234,8 @@ func TestPRStatus_currentBranch_defaultBranch(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -276,8 +276,8 @@ func TestPRStatus_currentBranch_Closed(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -301,8 +301,8 @@ func TestPRStatus_currentBranch_Closed_defaultBranch(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -326,8 +326,8 @@ func TestPRStatus_currentBranch_Merged(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -351,8 +351,8 @@ func TestPRStatus_currentBranch_Merged_defaultBranch(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -376,8 +376,8 @@ func TestPRStatus_blankSlate(t *testing.T) { defer cleanup(t) rs.Register(`git config --get-regexp \^branch\\.`, 0, "") rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref blueberries@{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") + rs.Register(`git rev-parse --symbolic-full-name blueberries@{push}`, 128, "") + rs.Register(`git config push.default`, 1, "") output, err := runCommand(http, "blueberries", true, "") if err != nil { @@ -432,14 +432,6 @@ func TestPRStatus_detachedHead(t *testing.T) { defer http.Verify(t) http.Register(httpmock.GraphQL(`query PullRequestStatus\b`), httpmock.StringResponse(`{"data": {}}`)) - // stub successful git command - rs, cleanup := run.Stub() - defer cleanup(t) - rs.Register(`git config --get-regexp \^branch\\.`, 0, "") - rs.Register(`git config remote.pushDefault`, 0, "") - rs.Register(`git rev-parse --abbrev-ref @{push}`, 0, "") - rs.Register(`git config push.default`, 0, "") - output, err := runCommand(http, "", true, "") if err != nil { t.Errorf("error running command `pr status`: %v", err) diff --git a/pkg/cmd/project/shared/queries/queries.go b/pkg/cmd/project/shared/queries/queries.go index 3e63465dd08..46aa584519c 100644 --- a/pkg/cmd/project/shared/queries/queries.go +++ b/pkg/cmd/project/shared/queries/queries.go @@ -7,9 +7,7 @@ import ( "net/url" "regexp" "strings" - "time" - "github.com/briandowns/spinner" "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/prompter" "github.com/cli/cli/v2/pkg/iostreams" @@ -24,8 +22,8 @@ func NewClient(httpClient *http.Client, hostname string, ios *iostreams.IOStream } return &Client{ apiClient: apiClient, - spinner: ios.IsStdoutTTY() && ios.IsStderrTTY(), - prompter: prompter.New("", ios.In, ios.Out, ios.ErrOut), + io: ios, + prompter: prompter.New("", ios), } } @@ -44,9 +42,10 @@ func NewTestClient(opts ...TestClientOpt) *Client { hostname: "github.com", Client: api.NewClientFromHTTP(http.DefaultClient), } + io, _, _, _ := iostreams.Test() c := &Client{ apiClient: apiClient, - spinner: false, + io: io, prompter: nil, } @@ -80,7 +79,7 @@ type graphqlClient interface { type Client struct { apiClient graphqlClient - spinner bool + io *iostreams.IOStreams prompter iprompter } @@ -89,19 +88,12 @@ const ( LimitMax = 100 // https://docs.github.com/en/graphql/overview/resource-limitations#node-limit ) -// doQuery wraps API calls with a visual spinner -func (c *Client) doQuery(name string, query interface{}, variables map[string]interface{}) error { - var sp *spinner.Spinner - if c.spinner { - // https://github.com/briandowns/spinner#available-character-sets - dotStyle := spinner.CharSets[11] - sp = spinner.New(dotStyle, 120*time.Millisecond, spinner.WithColor("fgCyan")) - sp.Start() - } +// doQueryWithProgressIndicator wraps API calls with a progress indicator. +// The query name is used in the progress indicator label. +func (c *Client) doQueryWithProgressIndicator(name string, query interface{}, variables map[string]interface{}) error { + c.io.StartProgressIndicatorWithLabel(fmt.Sprintf("Fetching %s", name)) + defer c.io.StopProgressIndicator() err := c.apiClient.Query(name, query, variables) - if sp != nil { - sp.Stop() - } return handleError(err) } @@ -552,7 +544,7 @@ func (c *Client) ProjectItems(o *Owner, number int32, limit int) (*Project, erro query = &viewerOwnerWithItems{} // must be a pointer to work with graphql queries queryName = "ViewerProjectWithItems" } - err := c.doQuery(queryName, query, variables) + err := c.doQueryWithProgressIndicator(queryName, query, variables) if err != nil { return project, err } @@ -706,7 +698,7 @@ func paginateAttributes[N projectAttribute](c *Client, p pager[N], variables map // set the cursor to the end of the last page variables[afterKey] = (*githubv4.String)(&cursor) - err := c.doQuery(queryName, p, variables) + err := c.doQueryWithProgressIndicator(queryName, p, variables) if err != nil { return nodes, err } @@ -863,7 +855,7 @@ func (c *Client) ProjectFields(o *Owner, number int32, limit int) (*Project, err query = &viewerOwnerWithFields{} // must be a pointer to work with graphql queries queryName = "ViewerProjectWithFields" } - err := c.doQuery(queryName, query, variables) + err := c.doQueryWithProgressIndicator(queryName, query, variables) if err != nil { return project, err } @@ -977,7 +969,7 @@ const ViewerOwner OwnerType = "VIEWER" // ViewerLoginName returns the login name of the viewer. func (c *Client) ViewerLoginName() (string, error) { var query viewerLogin - err := c.doQuery("Viewer", &query, map[string]interface{}{}) + err := c.doQueryWithProgressIndicator("Viewer", &query, map[string]interface{}{}) if err != nil { return "", err } @@ -988,7 +980,7 @@ func (c *Client) ViewerLoginName() (string, error) { func (c *Client) OwnerIDAndType(login string) (string, OwnerType, error) { if login == "@me" || login == "" { var query viewerLogin - err := c.doQuery("ViewerOwner", &query, nil) + err := c.doQueryWithProgressIndicator("ViewerOwner", &query, nil) if err != nil { return "", "", err } @@ -1009,7 +1001,7 @@ func (c *Client) OwnerIDAndType(login string) (string, OwnerType, error) { } `graphql:"organization(login: $login)"` } - err := c.doQuery("UserOrgOwner", &query, variables) + err := c.doQueryWithProgressIndicator("UserOrgOwner", &query, variables) if err != nil { // Due to the way the queries are structured, we don't know if a login belongs to a user // or to an org, even though they are unique. To deal with this, we try both - if neither @@ -1052,7 +1044,7 @@ func (c *Client) IssueOrPullRequestID(rawURL string) (string, error) { "url": githubv4.URI{URL: uri}, } var query issueOrPullRequest - err = c.doQuery("GetIssueOrPullRequest", &query, variables) + err = c.doQueryWithProgressIndicator("GetIssueOrPullRequest", &query, variables) if err != nil { return "", err } @@ -1114,7 +1106,7 @@ func (c *Client) userOrgLogins() ([]loginTypes, error) { "after": (*githubv4.String)(nil), } - err := c.doQuery("ViewerLoginAndOrgs", &v, variables) + err := c.doQueryWithProgressIndicator("ViewerLoginAndOrgs", &v, variables) if err != nil { return l, err } @@ -1152,7 +1144,7 @@ func (c *Client) paginateOrgLogins(l []loginTypes, cursor string) ([]loginTypes, "after": githubv4.String(cursor), } - err := c.doQuery("ViewerLoginAndOrgs", &v, variables) + err := c.doQueryWithProgressIndicator("ViewerLoginAndOrgs", &v, variables) if err != nil { return l, err } @@ -1247,16 +1239,16 @@ func (c *Client) NewProject(canPrompt bool, o *Owner, number int32, fields bool) if o.Type == UserOwner { var query userOwner variables["login"] = githubv4.String(o.Login) - err := c.doQuery("UserProject", &query, variables) + err := c.doQueryWithProgressIndicator("UserProject", &query, variables) return &query.Owner.Project, err } else if o.Type == OrgOwner { variables["login"] = githubv4.String(o.Login) var query orgOwner - err := c.doQuery("OrgProject", &query, variables) + err := c.doQueryWithProgressIndicator("OrgProject", &query, variables) return &query.Owner.Project, err } else if o.Type == ViewerOwner { var query viewerOwner - err := c.doQuery("ViewerProject", &query, variables) + err := c.doQueryWithProgressIndicator("ViewerProject", &query, variables) return &query.Owner.Project, err } return nil, errors.New("unknown owner type") @@ -1331,7 +1323,7 @@ func (c *Client) Projects(login string, t OwnerType, limit int, fields bool) (Pr // the cost. if t == UserOwner { var query userProjects - if err := c.doQuery("UserProjects", &query, variables); err != nil { + if err := c.doQueryWithProgressIndicator("UserProjects", &query, variables); err != nil { return projects, err } projects.Nodes = append(projects.Nodes, query.Owner.Projects.Nodes...) @@ -1340,7 +1332,7 @@ func (c *Client) Projects(login string, t OwnerType, limit int, fields bool) (Pr projects.TotalCount = query.Owner.Projects.TotalCount } else if t == OrgOwner { var query orgProjects - if err := c.doQuery("OrgProjects", &query, variables); err != nil { + if err := c.doQueryWithProgressIndicator("OrgProjects", &query, variables); err != nil { return projects, err } projects.Nodes = append(projects.Nodes, query.Owner.Projects.Nodes...) @@ -1349,7 +1341,7 @@ func (c *Client) Projects(login string, t OwnerType, limit int, fields bool) (Pr projects.TotalCount = query.Owner.Projects.TotalCount } else if t == ViewerOwner { var query viewerProjects - if err := c.doQuery("ViewerProjects", &query, variables); err != nil { + if err := c.doQueryWithProgressIndicator("ViewerProjects", &query, variables); err != nil { return projects, err } projects.Nodes = append(projects.Nodes, query.Owner.Projects.Nodes...) diff --git a/pkg/cmd/root/help_topic.go b/pkg/cmd/root/help_topic.go index b85d64ca37f..0c9534306aa 100644 --- a/pkg/cmd/root/help_topic.go +++ b/pkg/cmd/root/help_topic.go @@ -84,6 +84,8 @@ var HelpTopics = []helpTopic{ %[1]sGH_COLOR_LABELS%[1]s: set to any value to display labels using their RGB hex color codes in terminals that support truecolor. + %[1]sGH_ACCESSIBLE_COLORS%[1]s (preview): set to a truthy value to use customizable, 4-bit accessible colors. + %[1]sGH_FORCE_TTY%[1]s: set to any value to force terminal-style output even when the output is redirected. When the value is a number, it is interpreted as the number of columns available in the viewport. When the value is a percentage, it will be applied against @@ -114,6 +116,9 @@ var HelpTopics = []helpTopic{ %[1]sGH_ACCESSIBLE_PROMPTER%[1]s (preview): set to a truthy value to enable prompts that are more compatible with speech synthesis and braille screen readers. + + %[1]sGH_SPINNER_DISABLED%[1]s: set to a truthy value to replace the spinner animation with + a textual progress indicator. `, "`"), }, { diff --git a/pkg/httpmock/registry.go b/pkg/httpmock/registry.go index 387d0fc9560..b7c5a117df8 100644 --- a/pkg/httpmock/registry.go +++ b/pkg/httpmock/registry.go @@ -3,10 +3,10 @@ package httpmock import ( "fmt" "net/http" + "runtime/debug" + "strings" "sync" "testing" - - "github.com/stretchr/testify/assert" ) // Replace http.Client transport layer with registry so all requests get @@ -23,16 +23,28 @@ type Registry struct { func (r *Registry) Register(m Matcher, resp Responder) { r.stubs = append(r.stubs, &Stub{ + Stack: string(debug.Stack()), Matcher: m, Responder: resp, }) } func (r *Registry) Exclude(t *testing.T, m Matcher) { + registrationStack := string(debug.Stack()) + excludedStub := &Stub{ Matcher: m, Responder: func(req *http.Request) (*http.Response, error) { - assert.FailNowf(t, "Exclude error", "API called when excluded: %v", req.URL) + callStack := string(debug.Stack()) + + var errMsg strings.Builder + errMsg.WriteString("HTTP call was made when it should have been excluded:\n") + errMsg.WriteString(fmt.Sprintf("Request URL: %s\n", req.URL)) + errMsg.WriteString(fmt.Sprintf("Was excluded by: %s\n", registrationStack)) + errMsg.WriteString(fmt.Sprintf("Was called from: %s\n", callStack)) + + t.Error(errMsg.String()) + t.FailNow() return nil, nil }, exclude: true, @@ -46,17 +58,24 @@ type Testing interface { } func (r *Registry) Verify(t Testing) { - n := 0 + var unmatchedStubStacks []string for _, s := range r.stubs { if !s.matched && !s.exclude { - n++ + unmatchedStubStacks = append(unmatchedStubStacks, s.Stack) } } - if n > 0 { + if len(unmatchedStubStacks) > 0 { t.Helper() - // NOTE: stubs offer no useful reflection, so we can't print details + stacks := strings.Builder{} + for i, stack := range unmatchedStubStacks { + stacks.WriteString(fmt.Sprintf("Stub %d:\n", i+1)) + stacks.WriteString(fmt.Sprintf("\t%s", stack)) + if stack != unmatchedStubStacks[len(unmatchedStubStacks)-1] { + stacks.WriteString("\n") + } + } // about dead stubs and what they were trying to match - t.Errorf("%d unmatched HTTP stubs", n) + t.Errorf("%d HTTP stubs unmatched, stacks:\n%s", len(unmatchedStubStacks), stacks.String()) } } @@ -84,7 +103,7 @@ func (r *Registry) RoundTrip(req *http.Request) (*http.Response, error) { if stub == nil { r.mu.Unlock() - return nil, fmt.Errorf("no registered stubs matched %v", req) + return nil, fmt.Errorf("no registered HTTP stubs matched %v", req) } r.Requests = append(r.Requests, req) diff --git a/pkg/httpmock/stub.go b/pkg/httpmock/stub.go index 4e61d12f44f..745c1241743 100644 --- a/pkg/httpmock/stub.go +++ b/pkg/httpmock/stub.go @@ -15,6 +15,7 @@ type Matcher func(req *http.Request) bool type Responder func(req *http.Request) (*http.Response, error) type Stub struct { + Stack string matched bool Matcher Matcher Responder Responder diff --git a/pkg/iostreams/iostreams.go b/pkg/iostreams/iostreams.go index f5e3c2aee39..22f966ac810 100644 --- a/pkg/iostreams/iostreams.go +++ b/pkg/iostreams/iostreams.go @@ -58,6 +58,7 @@ type IOStreams struct { progressIndicatorEnabled bool progressIndicator *spinner.Spinner progressIndicatorMu sync.Mutex + spinnerDisabled bool alternateScreenBufferEnabled bool alternateScreenBufferActive bool @@ -78,7 +79,8 @@ type IOStreams struct { pagerCommand string pagerProcess *os.Process - neverPrompt bool + neverPrompt bool + accessiblePrompterEnabled bool TempFileOverride *os.File } @@ -273,6 +275,14 @@ func (s *IOStreams) SetNeverPrompt(v bool) { s.neverPrompt = v } +func (s *IOStreams) GetSpinnerDisabled() bool { + return s.spinnerDisabled +} + +func (s *IOStreams) SetSpinnerDisabled(v bool) { + s.spinnerDisabled = v +} + func (s *IOStreams) StartProgressIndicator() { s.StartProgressIndicatorWithLabel("") } @@ -282,6 +292,15 @@ func (s *IOStreams) StartProgressIndicatorWithLabel(label string) { return } + if s.spinnerDisabled { + // If the spinner is disabled, simply print a + // textual progress indicator and return. + // This means that s.ProgressIndicator will be nil. + // See also: the comment on StopProgressIndicator() + s.startTextualProgressIndicator(label) + return + } + s.progressIndicatorMu.Lock() defer s.progressIndicatorMu.Unlock() @@ -295,8 +314,10 @@ func (s *IOStreams) StartProgressIndicatorWithLabel(label string) { } // https://github.com/briandowns/spinner#available-character-sets - dotStyle := spinner.CharSets[11] - sp := spinner.New(dotStyle, 120*time.Millisecond, spinner.WithWriter(s.ErrOut), spinner.WithColor("fgCyan")) + // ⣾ ⣷ ⣽ ⣻ ⡿ + spinnerStyle := spinner.CharSets[11] + + sp := spinner.New(spinnerStyle, 120*time.Millisecond, spinner.WithWriter(s.ErrOut), spinner.WithColor("fgCyan")) if label != "" { sp.Prefix = label + " " } @@ -305,6 +326,27 @@ func (s *IOStreams) StartProgressIndicatorWithLabel(label string) { s.progressIndicator = sp } +func (s *IOStreams) startTextualProgressIndicator(label string) { + s.progressIndicatorMu.Lock() + defer s.progressIndicatorMu.Unlock() + + // Default label when spinner disabled is "Working..." + if label == "" { + label = "Working..." + } + + // Add an ellipsis to the label if it doesn't already have one. + ellipsis := "..." + if !strings.HasSuffix(label, ellipsis) { + label = label + ellipsis + } + + fmt.Fprintf(s.ErrOut, "%s%s", s.ColorScheme().Cyan(label), "\n") +} + +// StopProgressIndicator stops the progress indicator if it is running. +// Note that a textual progess indicator does not create a progress indicator, +// so this method is a no-op in that case. func (s *IOStreams) StopProgressIndicator() { s.progressIndicatorMu.Lock() defer s.progressIndicatorMu.Unlock() @@ -416,6 +458,14 @@ func (s *IOStreams) AccessibleColorsEnabled() bool { return s.accessibleColorsEnabled } +func (s *IOStreams) SetAccessiblePrompterEnabled(enabled bool) { + s.accessiblePrompterEnabled = enabled +} + +func (s *IOStreams) AccessiblePrompterEnabled() bool { + return s.accessiblePrompterEnabled +} + func System() *IOStreams { terminal := ghTerm.FromEnv() diff --git a/pkg/iostreams/iostreams_progress_indicator_test.go b/pkg/iostreams/iostreams_progress_indicator_test.go new file mode 100644 index 00000000000..60d0ece91e3 --- /dev/null +++ b/pkg/iostreams/iostreams_progress_indicator_test.go @@ -0,0 +1,254 @@ +//go:build !windows + +package iostreams + +import ( + "fmt" + "io" + "os" + "strings" + "testing" + "time" + + "github.com/Netflix/go-expect" + "github.com/creack/pty" + "github.com/hinshun/vt10x" + "github.com/stretchr/testify/require" +) + +func TestStartProgressIndicatorWithLabel(t *testing.T) { + osOut := os.Stdout + defer func() { os.Stdout = osOut }() + // Why do we need a channel in these tests to implement a timeout instead of + // relying on expect's timeout? + // + // Well, expect's timeout is based on the maximum time of a single read + // from the console. This works in cases like prompting where we block + // waiting for input because the console is not ready to be read. + // But in this case, we are not blocking waiting for input and stdout + // can be constantly read. This means the timeout will never be reached + // in the event of a expectation failure. + // To fix this, we need to implement our own timeout that is based + // specifically on the total time spent reading the console and waiting + // for the target string instead of the max time for a single read + // from the console. + t.Run("progress indicator respects GH_SPINNER_DISABLED is true", func(t *testing.T) { + console := newTestVirtualTerminal(t) + io := newTestIOStreams(t, console, true) + + done := make(chan error) + + go func() { + _, err := console.ExpectString("Working...") + done <- err + }() + + io.StartProgressIndicatorWithLabel("") + defer io.StopProgressIndicator() + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for progress indicator") + } + }) + + t.Run("progress indicator respects GH_SPINNER_DISABLED is false", func(t *testing.T) { + console := newTestVirtualTerminal(t) + io := newTestIOStreams(t, console, false) + + done := make(chan error) + + go func() { + _, err := console.ExpectString("⣾") + done <- err + }() + + io.StartProgressIndicatorWithLabel("") + defer io.StopProgressIndicator() + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for progress indicator") + } + }) + + t.Run("progress indicator with GH_SPINNER_DISABLED shows label", func(t *testing.T) { + console := newTestVirtualTerminal(t) + io := newTestIOStreams(t, console, true) + progressIndicatorLabel := "downloading happiness" + + done := make(chan error) + + go func() { + _, err := console.ExpectString(progressIndicatorLabel + "...") + done <- err + }() + + io.StartProgressIndicatorWithLabel(progressIndicatorLabel) + defer io.StopProgressIndicator() + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for progress indicator") + } + }) + + t.Run("progress indicator shows label and spinner", func(t *testing.T) { + console := newTestVirtualTerminal(t) + io := newTestIOStreams(t, console, false) + progressIndicatorLabel := "downloading happiness" + + done := make(chan error) + + go func() { + _, err := console.ExpectString(progressIndicatorLabel) + require.NoError(t, err) + _, err = console.ExpectString("⣾") + done <- err + }() + + io.StartProgressIndicatorWithLabel(progressIndicatorLabel) + defer io.StopProgressIndicator() + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for progress indicator") + } + }) + + t.Run("multiple calls to start progress indicator with GH_SPINNER_DISABLED prints additional labels", func(t *testing.T) { + console := newTestVirtualTerminal(t) + io := newTestIOStreams(t, console, true) + progressIndicatorLabel1 := "downloading happiness" + progressIndicatorLabel2 := "downloading sadness" + done := make(chan error) + go func() { + _, err := console.ExpectString(progressIndicatorLabel1 + "...") + require.NoError(t, err) + _, err = console.ExpectString(progressIndicatorLabel2 + "...") + done <- err + }() + io.StartProgressIndicatorWithLabel(progressIndicatorLabel1) + defer io.StopProgressIndicator() + io.StartProgressIndicatorWithLabel(progressIndicatorLabel2) + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(2 * time.Second): + t.Fatal("Test timed out waiting for progress indicator") + } + }) +} + +func newTestVirtualTerminal(t *testing.T) *expect.Console { + t.Helper() + + // Create a PTY and hook up a virtual terminal emulator + ptm, pts, err := pty.Open() + require.NoError(t, err) + + term := vt10x.New(vt10x.WithWriter(pts)) + + // Create a console via Expect that allows scripting against the terminal + consoleOpts := []expect.ConsoleOpt{ + expect.WithStdin(ptm), + expect.WithStdout(term), + expect.WithCloser(ptm, pts), + failOnExpectError(t), + failOnSendError(t), + expect.WithDefaultTimeout(time.Second), + } + + console, err := expect.NewConsole(consoleOpts...) + require.NoError(t, err) + t.Cleanup(func() { testCloser(t, console) }) + + return console +} + +func newTestIOStreams(t *testing.T, console *expect.Console, spinnerDisabled bool) *IOStreams { + t.Helper() + + in := console.Tty() + out := console.Tty() + errOut := console.Tty() + + // Because the briandowns/spinner checks os.Stdout directly, + // we need this hack to trick it into allowing the spinner to print... + os.Stdout = out + + io := &IOStreams{ + In: in, + Out: out, + ErrOut: errOut, + term: fakeTerm{}, + } + io.progressIndicatorEnabled = true + io.SetSpinnerDisabled(spinnerDisabled) + return io +} + +// failOnExpectError adds an observer that will fail the test in a standardised way +// if any expectation on the command output fails, without requiring an explicit +// assertion. +// +// Use WithRelaxedIO to disable this behaviour. +func failOnExpectError(t *testing.T) expect.ConsoleOpt { + t.Helper() + return expect.WithExpectObserver( + func(matchers []expect.Matcher, buf string, err error) { + t.Helper() + + if err == nil { + return + } + + if len(matchers) == 0 { + t.Fatalf("Error occurred while matching %q: %s\n", buf, err) + } + + var criteria []string + for _, matcher := range matchers { + criteria = append(criteria, fmt.Sprintf("%q", matcher.Criteria())) + } + t.Fatalf("Failed to find [%s] in %q: %s\n", strings.Join(criteria, ", "), buf, err) + }, + ) +} + +// failOnSendError adds an observer that will fail the test in a standardised way +// if any sending of input fails, without requiring an explicit assertion. +// +// Use WithRelaxedIO to disable this behaviour. +func failOnSendError(t *testing.T) expect.ConsoleOpt { + t.Helper() + return expect.WithSendObserver( + func(msg string, n int, err error) { + t.Helper() + + if err != nil { + t.Fatalf("Failed to send %q: %s\n", msg, err) + } + if len(msg) != n { + t.Fatalf("Only sent %d of %d bytes for %q\n", n, len(msg), msg) + } + }, + ) +} + +// testCloser is a helper to fail the test if a Closer fails to close. +func testCloser(t *testing.T, closer io.Closer) { + t.Helper() + if err := closer.Close(); err != nil { + t.Errorf("Close failed: %s", err) + } +} diff --git a/pkg/option/option.go b/pkg/option/option.go index 8d3b70f3f7c..caf26dd0b32 100644 --- a/pkg/option/option.go +++ b/pkg/option/option.go @@ -46,6 +46,15 @@ func None[T any]() Option[T] { return Option[T]{} } +func SomeIfNonZero[T comparable](value T) Option[T] { + // value is a zero value then return a None + var zero T + if value == zero { + return None[T]() + } + return Some(value) +} + // String implements the [fmt.Stringer] interface. func (o Option[T]) String() string { if o.present { diff --git a/pkg/search/result.go b/pkg/search/result.go index 0c7c43cd7c5..0b9d1ab168e 100644 --- a/pkg/search/result.go +++ b/pkg/search/result.go @@ -93,25 +93,29 @@ var PullRequestFields = append(IssueFields, type CodeResult struct { IncompleteResults bool `json:"incomplete_results"` Items []Code `json:"items"` - Total int `json:"total_count"` + // Number of code search results matching the query on the server. Ignoring limit. + Total int `json:"total_count"` } type CommitsResult struct { IncompleteResults bool `json:"incomplete_results"` Items []Commit `json:"items"` - Total int `json:"total_count"` + // Number of commits matching the query on the server. Ignoring limit. + Total int `json:"total_count"` } type RepositoriesResult struct { IncompleteResults bool `json:"incomplete_results"` Items []Repository `json:"items"` - Total int `json:"total_count"` + // Number of repositories matching the query on the server. Ignoring limit. + Total int `json:"total_count"` } type IssuesResult struct { IncompleteResults bool `json:"incomplete_results"` Items []Issue `json:"items"` - Total int `json:"total_count"` + // Number of isssues matching the query on the server. Ignoring limit. + Total int `json:"total_count"` } type Code struct { diff --git a/pkg/search/searcher.go b/pkg/search/searcher.go index 4168dc7f3a5..7cbd355623b 100644 --- a/pkg/search/searcher.go +++ b/pkg/search/searcher.go @@ -14,6 +14,7 @@ import ( ) const ( + // GitHub API has a limit of 100 per page maxPerPage = 100 orderKey = "order" sortKey = "sort" @@ -60,100 +61,145 @@ func NewSearcher(client *http.Client, host string) Searcher { func (s searcher) Code(query Query) (CodeResult, error) { result := CodeResult{} - toRetrieve := query.Limit + var resp *http.Response var err error - for toRetrieve > 0 { - query.Limit = min(toRetrieve, maxPerPage) + + // We will request either the query limit if it's less than 1 page, or our max page size. + // This number doesn't change to keep a valid offset. + // + // For example, say we want 150 items out of 500. + // We request page #1 for 100 items and get items 0 to 99. + // Then we request page #2 for 100 items, we get items 100 to 199 and only keep 100 to 149. + // If we were to request page #2 for 50 items, we would instead get items 50 to 99. + numItemsToRetrieve := query.Limit + query.Limit = min(numItemsToRetrieve, maxPerPage) + + for numItemsToRetrieve > 0 { query.Page = nextPage(resp) if query.Page == 0 { break } + page := CodeResult{} resp, err = s.search(query, &page) if err != nil { return result, err } + + // If we're going to reach the requested limit, only add that many items, + // otherwise add all the results. + numItemsToAdd := min(len(page.Items), numItemsToRetrieve) result.IncompleteResults = page.IncompleteResults + // The API returns how many items match the query in every response. + // With the example above, this would be 500. result.Total = page.Total - result.Items = append(result.Items, page.Items...) - toRetrieve = toRetrieve - len(page.Items) + result.Items = append(result.Items, page.Items[:numItemsToAdd]...) + numItemsToRetrieve = numItemsToRetrieve - numItemsToAdd } + return result, nil } func (s searcher) Commits(query Query) (CommitsResult, error) { result := CommitsResult{} - toRetrieve := query.Limit + var resp *http.Response var err error - for toRetrieve > 0 { - query.Limit = min(toRetrieve, maxPerPage) + + numItemsToRetrieve := query.Limit + query.Limit = min(numItemsToRetrieve, maxPerPage) + + for numItemsToRetrieve > 0 { query.Page = nextPage(resp) if query.Page == 0 { break } + page := CommitsResult{} resp, err = s.search(query, &page) if err != nil { return result, err } + + numItemsToAdd := min(len(page.Items), numItemsToRetrieve) result.IncompleteResults = page.IncompleteResults result.Total = page.Total - result.Items = append(result.Items, page.Items...) - toRetrieve = toRetrieve - len(page.Items) + result.Items = append(result.Items, page.Items[:numItemsToAdd]...) + numItemsToRetrieve = numItemsToRetrieve - numItemsToAdd } return result, nil } func (s searcher) Repositories(query Query) (RepositoriesResult, error) { result := RepositoriesResult{} - toRetrieve := query.Limit + var resp *http.Response var err error - for toRetrieve > 0 { - query.Limit = min(toRetrieve, maxPerPage) + + numItemsToRetrieve := query.Limit + query.Limit = min(numItemsToRetrieve, maxPerPage) + + for numItemsToRetrieve > 0 { query.Page = nextPage(resp) if query.Page == 0 { break } + page := RepositoriesResult{} resp, err = s.search(query, &page) if err != nil { return result, err } + + numItemsToAdd := min(len(page.Items), numItemsToRetrieve) result.IncompleteResults = page.IncompleteResults result.Total = page.Total - result.Items = append(result.Items, page.Items...) - toRetrieve = toRetrieve - len(page.Items) + result.Items = append(result.Items, page.Items[:numItemsToAdd]...) + numItemsToRetrieve = numItemsToRetrieve - numItemsToAdd } return result, nil } func (s searcher) Issues(query Query) (IssuesResult, error) { result := IssuesResult{} - toRetrieve := query.Limit + var resp *http.Response var err error - for toRetrieve > 0 { - query.Limit = min(toRetrieve, maxPerPage) + + numItemsToRetrieve := query.Limit + query.Limit = min(numItemsToRetrieve, maxPerPage) + + for numItemsToRetrieve > 0 { query.Page = nextPage(resp) if query.Page == 0 { break } + page := IssuesResult{} resp, err = s.search(query, &page) if err != nil { return result, err } + + numItemsToAdd := min(len(page.Items), numItemsToRetrieve) result.IncompleteResults = page.IncompleteResults result.Total = page.Total - result.Items = append(result.Items, page.Items...) - toRetrieve = toRetrieve - len(page.Items) + result.Items = append(result.Items, page.Items[:numItemsToAdd]...) + numItemsToRetrieve = numItemsToRetrieve - numItemsToAdd } return result, nil } +// search makes a single-page REST search request for code, commits, issues, prs, or repos. +// +// The result argument is populated with the following information: +// +// - Total: the number of search results matching the query, which may exceed the number of items returned +// - IncompleteResults: whether the search request exceeded search time limit, potentially being incomplete +// - Items: the actual matching search results, up to 100 max items per page +// +// For more information, see https://docs.github.com/en/rest/search/search?apiVersion=2022-11-28. func (s searcher) search(query Query, result interface{}) (*http.Response, error) { path := fmt.Sprintf("%ssearch/%s", ghinstance.RESTPrefix(s.host), query.Kind) qs := url.Values{} @@ -236,10 +282,15 @@ func handleHTTPError(resp *http.Response) error { return httpError } +// https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api func nextPage(resp *http.Response) (page int) { if resp == nil { return 1 } + + // When using pagination, responses get a "Link" field in their header. + // When a next page is available, "Link" contains a link to the next page + // tagged with rel="next". for _, m := range linkRE.FindAllStringSubmatch(resp.Header.Get("Link"), -1) { if !(len(m) > 2 && m[2] == "next") { continue diff --git a/pkg/search/searcher_test.go b/pkg/search/searcher_test.go index 8642feed097..e893c9a3b92 100644 --- a/pkg/search/searcher_test.go +++ b/pkg/search/searcher_test.go @@ -1,8 +1,10 @@ package search import ( + "fmt" "net/http" "net/url" + "strconv" "testing" "github.com/MakeNowJust/heredoc" @@ -46,10 +48,14 @@ func TestSearcherCode(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "search/code", values), - httpmock.JSONResponse(CodeResult{ - IncompleteResults: false, - Items: []Code{{Name: "file.go"}}, - Total: 1, + httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 1, + "items": []interface{}{ + map[string]interface{}{ + "name": "file.go", + }, + }, }), ) }, @@ -66,10 +72,14 @@ func TestSearcherCode(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "api/v3/search/code", values), - httpmock.JSONResponse(CodeResult{ - IncompleteResults: false, - Items: []Code{{Name: "file.go"}}, - Total: 1, + httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 1, + "items": []interface{}{ + map[string]interface{}{ + "name": "file.go", + }, + }, }), ) }, @@ -84,25 +94,83 @@ func TestSearcherCode(t *testing.T) { }, httpStubs: func(reg *httpmock.Registry) { firstReq := httpmock.QueryMatcher("GET", "search/code", values) - firstRes := httpmock.JSONResponse(CodeResult{ - IncompleteResults: false, - Items: []Code{{Name: "file.go"}}, - Total: 2, + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 2, + "items": []interface{}{ + map[string]interface{}{ + "name": "file.go", + }, + }, + }) + firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) + secondReq := httpmock.QueryMatcher("GET", "search/code", url.Values{ + "page": []string{"2"}, + "per_page": []string{"30"}, + "q": []string{"keyword language:go"}, + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 2, + "items": []interface{}{ + map[string]interface{}{ + "name": "file2.go", + }, + }, + }) + reg.Register(firstReq, firstRes) + reg.Register(secondReq, secondRes) + }, + }, + { + name: "collect full and partial pages under total number of matching search results", + query: Query{ + Keywords: []string{"keyword"}, + Kind: "code", + Limit: 110, + Qualifiers: Qualifiers{ + Language: "go", }, - ) + }, + result: CodeResult{ + IncompleteResults: false, + Items: initialize(0, 110, func(i int) Code { + return Code{ + Name: fmt.Sprintf("name%d.go", i), + } + }), + Total: 287, + }, + httpStubs: func(reg *httpmock.Registry) { + firstReq := httpmock.QueryMatcher("GET", "search/code", url.Values{ + "page": []string{"1"}, + "per_page": []string{"100"}, + "q": []string{"keyword language:go"}, + }) + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(0, 100, func(i int) interface{} { + return map[string]interface{}{ + "name": fmt.Sprintf("name%d.go", i), + } + }), + }) firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) secondReq := httpmock.QueryMatcher("GET", "search/code", url.Values{ "page": []string{"2"}, - "per_page": []string{"29"}, + "per_page": []string{"100"}, "q": []string{"keyword language:go"}, - }, - ) - secondRes := httpmock.JSONResponse(CodeResult{ - IncompleteResults: false, - Items: []Code{{Name: "file2.go"}}, - Total: 2, - }, - ) + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(100, 200, func(i int) interface{} { + return map[string]interface{}{ + "name": fmt.Sprintf("name%d.go", i), + } + }), + }) reg.Register(firstReq, firstRes) reg.Register(secondReq, secondRes) }, @@ -201,10 +269,14 @@ func TestSearcherCommits(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "search/commits", values), - httpmock.JSONResponse(CommitsResult{ - IncompleteResults: false, - Items: []Commit{{Sha: "abc"}}, - Total: 1, + httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 1, + "items": []interface{}{ + map[string]interface{}{ + "sha": "abc", + }, + }, }), ) }, @@ -221,10 +293,14 @@ func TestSearcherCommits(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "api/v3/search/commits", values), - httpmock.JSONResponse(CommitsResult{ - IncompleteResults: false, - Items: []Commit{{Sha: "abc"}}, - Total: 1, + httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 1, + "items": []interface{}{ + map[string]interface{}{ + "sha": "abc", + }, + }, }), ) }, @@ -239,27 +315,92 @@ func TestSearcherCommits(t *testing.T) { }, httpStubs: func(reg *httpmock.Registry) { firstReq := httpmock.QueryMatcher("GET", "search/commits", values) - firstRes := httpmock.JSONResponse(CommitsResult{ - IncompleteResults: false, - Items: []Commit{{Sha: "abc"}}, - Total: 2, + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 2, + "items": []interface{}{ + map[string]interface{}{ + "sha": "abc", + }, + }, + }) + firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) + secondReq := httpmock.QueryMatcher("GET", "search/commits", url.Values{ + "page": []string{"2"}, + "per_page": []string{"30"}, + "order": []string{"desc"}, + "sort": []string{"committer-date"}, + "q": []string{"keyword author:foobar committer-date:>2021-02-28"}, + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 2, + "items": []interface{}{ + map[string]interface{}{ + "sha": "def", + }, + }, + }) + reg.Register(firstReq, firstRes) + reg.Register(secondReq, secondRes) + }, + }, + { + name: "collect full and partial pages under total number of matching search results", + query: Query{ + Keywords: []string{"keyword"}, + Kind: "commits", + Limit: 110, + Order: "desc", + Sort: "committer-date", + Qualifiers: Qualifiers{ + Author: "foobar", + CommitterDate: ">2021-02-28", }, - ) + }, + result: CommitsResult{ + IncompleteResults: false, + Items: initialize(0, 110, func(i int) Commit { + return Commit{ + Sha: strconv.Itoa(i), + } + }), + Total: 287, + }, + httpStubs: func(reg *httpmock.Registry) { + firstReq := httpmock.QueryMatcher("GET", "search/commits", url.Values{ + "page": []string{"1"}, + "per_page": []string{"100"}, + "order": []string{"desc"}, + "sort": []string{"committer-date"}, + "q": []string{"keyword author:foobar committer-date:>2021-02-28"}, + }) + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(0, 100, func(i int) map[string]interface{} { + return map[string]interface{}{ + "sha": strconv.Itoa(i), + } + }), + }) firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) secondReq := httpmock.QueryMatcher("GET", "search/commits", url.Values{ "page": []string{"2"}, - "per_page": []string{"29"}, + "per_page": []string{"100"}, "order": []string{"desc"}, "sort": []string{"committer-date"}, "q": []string{"keyword author:foobar committer-date:>2021-02-28"}, - }, - ) - secondRes := httpmock.JSONResponse(CommitsResult{ - IncompleteResults: false, - Items: []Commit{{Sha: "def"}}, - Total: 2, - }, - ) + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(100, 200, func(i int) map[string]interface{} { + return map[string]interface{}{ + "sha": strconv.Itoa(i), + } + }), + }) reg.Register(firstReq, firstRes) reg.Register(secondReq, secondRes) }, @@ -269,8 +410,8 @@ func TestSearcherCommits(t *testing.T) { query: query, wantErr: true, errMsg: heredoc.Doc(` - Invalid search query "keyword author:foobar committer-date:>2021-02-28". - "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), + Invalid search query "keyword author:foobar committer-date:>2021-02-28". + "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "search/commits", values), @@ -413,15 +554,14 @@ func TestSearcherRepositories(t *testing.T) { }, }, }) - firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) + firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) secondReq := httpmock.QueryMatcher("GET", "search/repositories", url.Values{ "page": []string{"2"}, - "per_page": []string{"29"}, + "per_page": []string{"30"}, "order": []string{"desc"}, "sort": []string{"stars"}, "q": []string{"keyword stars:>=5 topic:topic"}, - }, - ) + }) secondRes := httpmock.JSONResponse(map[string]interface{}{ "incomplete_results": false, "total_count": 2, @@ -435,13 +575,73 @@ func TestSearcherRepositories(t *testing.T) { reg.Register(secondReq, secondRes) }, }, + { + name: "collect full and partial pages under total number of matching search results", + query: Query{ + Keywords: []string{"keyword"}, + Kind: "repositories", + Limit: 110, + Order: "desc", + Sort: "stars", + Qualifiers: Qualifiers{ + Stars: ">=5", + Topic: []string{"topic"}, + }, + }, + result: RepositoriesResult{ + IncompleteResults: false, + Items: initialize(0, 110, func(i int) Repository { + return Repository{ + Name: fmt.Sprintf("name%d", i), + } + }), + Total: 287, + }, + httpStubs: func(reg *httpmock.Registry) { + firstReq := httpmock.QueryMatcher("GET", "search/repositories", url.Values{ + "page": []string{"1"}, + "per_page": []string{"100"}, + "order": []string{"desc"}, + "sort": []string{"stars"}, + "q": []string{"keyword stars:>=5 topic:topic"}, + }) + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(0, 100, func(i int) interface{} { + return map[string]interface{}{ + "name": fmt.Sprintf("name%d", i), + } + }), + }) + firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) + secondReq := httpmock.QueryMatcher("GET", "search/repositories", url.Values{ + "page": []string{"2"}, + "per_page": []string{"100"}, + "order": []string{"desc"}, + "sort": []string{"stars"}, + "q": []string{"keyword stars:>=5 topic:topic"}, + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(100, 200, func(i int) interface{} { + return map[string]interface{}{ + "name": fmt.Sprintf("name%d", i), + } + }), + }) + reg.Register(firstReq, firstRes) + reg.Register(secondReq, secondRes) + }, + }, { name: "handles search errors", query: query, wantErr: true, errMsg: heredoc.Doc(` - Invalid search query "keyword stars:>=5 topic:topic". - "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), + Invalid search query "keyword stars:>=5 topic:topic". + "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "search/repositories", values), @@ -529,10 +729,14 @@ func TestSearcherIssues(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "search/issues", values), - httpmock.JSONResponse(IssuesResult{ - IncompleteResults: false, - Items: []Issue{{Number: 1234}}, - Total: 1, + httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 1, + "items": []interface{}{ + map[string]interface{}{ + "number": 1234, + }, + }, }), ) }, @@ -549,10 +753,14 @@ func TestSearcherIssues(t *testing.T) { httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "api/v3/search/issues", values), - httpmock.JSONResponse(IssuesResult{ - IncompleteResults: false, - Items: []Issue{{Number: 1234}}, - Total: 1, + httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 1, + "items": []interface{}{ + map[string]interface{}{ + "number": 1234, + }, + }, }), ) }, @@ -567,27 +775,92 @@ func TestSearcherIssues(t *testing.T) { }, httpStubs: func(reg *httpmock.Registry) { firstReq := httpmock.QueryMatcher("GET", "search/issues", values) - firstRes := httpmock.JSONResponse(IssuesResult{ - IncompleteResults: false, - Items: []Issue{{Number: 1234}}, - Total: 2, + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 2, + "items": []interface{}{ + map[string]interface{}{ + "number": 1234, + }, + }, + }) + firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) + secondReq := httpmock.QueryMatcher("GET", "search/issues", url.Values{ + "page": []string{"2"}, + "per_page": []string{"30"}, + "order": []string{"desc"}, + "sort": []string{"comments"}, + "q": []string{"keyword is:locked is:public language:go"}, + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 2, + "items": []interface{}{ + map[string]interface{}{ + "number": 5678, + }, + }, + }) + reg.Register(firstReq, firstRes) + reg.Register(secondReq, secondRes) + }, + }, + { + name: "collect full and partial pages under total number of matching search results", + query: Query{ + Keywords: []string{"keyword"}, + Kind: "issues", + Limit: 110, + Order: "desc", + Sort: "comments", + Qualifiers: Qualifiers{ + Language: "go", + Is: []string{"public", "locked"}, }, - ) + }, + result: IssuesResult{ + IncompleteResults: false, + Items: initialize(0, 110, func(i int) Issue { + return Issue{ + Number: i, + } + }), + Total: 287, + }, + httpStubs: func(reg *httpmock.Registry) { + firstReq := httpmock.QueryMatcher("GET", "search/issues", url.Values{ + "page": []string{"1"}, + "per_page": []string{"100"}, + "order": []string{"desc"}, + "sort": []string{"comments"}, + "q": []string{"keyword is:locked is:public language:go"}, + }) + firstRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(0, 100, func(i int) interface{} { + return map[string]interface{}{ + "number": i, + } + }), + }) firstRes = httpmock.WithHeader(firstRes, "Link", `; rel="next"`) secondReq := httpmock.QueryMatcher("GET", "search/issues", url.Values{ "page": []string{"2"}, - "per_page": []string{"29"}, + "per_page": []string{"100"}, "order": []string{"desc"}, "sort": []string{"comments"}, "q": []string{"keyword is:locked is:public language:go"}, - }, - ) - secondRes := httpmock.JSONResponse(IssuesResult{ - IncompleteResults: false, - Items: []Issue{{Number: 5678}}, - Total: 2, - }, - ) + }) + secondRes := httpmock.JSONResponse(map[string]interface{}{ + "incomplete_results": false, + "total_count": 287, + "items": initialize(100, 200, func(i int) interface{} { + return map[string]interface{}{ + "number": i, + } + }), + }) reg.Register(firstReq, firstRes) reg.Register(secondReq, secondRes) }, @@ -597,8 +870,8 @@ func TestSearcherIssues(t *testing.T) { query: query, wantErr: true, errMsg: heredoc.Doc(` - Invalid search query "keyword is:locked is:public language:go". - "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), + Invalid search query "keyword is:locked is:public language:go". + "blah" is not a recognized date/time format. Please provide an ISO 8601 date/time value, such as YYYY-MM-DD.`), httpStubs: func(reg *httpmock.Registry) { reg.Register( httpmock.QueryMatcher("GET", "search/issues", values), @@ -686,3 +959,12 @@ func TestSearcherURL(t *testing.T) { }) } } + +// initialize generate slices over a range for test scenarios using the provided initializer. +func initialize[T any](start int, stop int, initializer func(i int) T) []T { + results := make([]T, 0, (stop - start)) + for i := start; i < stop; i++ { + results = append(results, initializer(i)) + } + return results +}