From b4b8e8a13bb99bae479929048ea5f64808d0bc32 Mon Sep 17 00:00:00 2001 From: Andrew Kaster Date: Tue, 3 Dec 2024 21:43:28 -0700 Subject: [PATCH] Allow argument to `gh repo set-default` to be the name of a remote --- pkg/cmd/repo/setdefault/setdefault.go | 35 +++++-- pkg/cmd/repo/setdefault/setdefault_test.go | 114 ++++++++++++++++----- 2 files changed, 115 insertions(+), 34 deletions(-) diff --git a/pkg/cmd/repo/setdefault/setdefault.go b/pkg/cmd/repo/setdefault/setdefault.go index eb8fcfb5a57..3e4ea13dd8d 100644 --- a/pkg/cmd/repo/setdefault/setdefault.go +++ b/pkg/cmd/repo/setdefault/setdefault.go @@ -78,14 +78,6 @@ func NewCmdSetDefault(f *cmdutil.Factory, runF func(*SetDefaultOptions) error) * `), Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - if len(args) > 0 { - var err error - opts.Repo, err = ghrepo.FromFullName(args[0]) - if err != nil { - return err - } - } - if !opts.ViewMode && !opts.IO.CanPrompt() && opts.Repo == nil { return cmdutil.FlagErrorf("repository required when not running interactively") } @@ -96,6 +88,14 @@ func NewCmdSetDefault(f *cmdutil.Factory, runF func(*SetDefaultOptions) error) * return errors.New("must be run from inside a git repository") } + if len(args) > 0 { + var err error + opts.Repo, err = parseRepo(args[0], opts) + if err != nil { + return err + } + } + if runF != nil { return runF(opts) } @@ -252,3 +252,22 @@ func displayRemoteRepoName(remote *context.Remote) string { return ghrepo.FullName(repo) } + +func parseRepo(name string, opts *SetDefaultOptions) (ghrepo.Interface, error) { + if repo, err := ghrepo.FromFullName(name); err == nil { + return repo, nil + } + + remotes, err := opts.Remotes() + if err != nil { + return nil, err + } + + for _, remote := range remotes { + if remote.Name == name { + return remote.Repo, nil + } + } + + return nil, fmt.Errorf(`expected the "[HOST/]OWNER/REPO" format or a remote name, got %q`, name) +} diff --git a/pkg/cmd/repo/setdefault/setdefault_test.go b/pkg/cmd/repo/setdefault/setdefault_test.go index 55a0193d593..27bf160d9e8 100644 --- a/pkg/cmd/repo/setdefault/setdefault_test.go +++ b/pkg/cmd/repo/setdefault/setdefault_test.go @@ -21,6 +21,7 @@ func TestNewCmdSetDefault(t *testing.T) { tests := []struct { name string gitStubs func(*run.CommandStubber) + remotes []*context.Remote input string output SetDefaultOptions wantErr bool @@ -31,45 +32,104 @@ func TestNewCmdSetDefault(t *testing.T) { gitStubs: func(cs *run.CommandStubber) { cs.Register(`git rev-parse --git-dir`, 0, ".git") }, - input: "", - output: SetDefaultOptions{}, + remotes: []*context.Remote{}, + input: "", + output: SetDefaultOptions{}, }, { name: "repo argument", gitStubs: func(cs *run.CommandStubber) { cs.Register(`git rev-parse --git-dir`, 0, ".git") }, - input: "cli/cli", - output: SetDefaultOptions{Repo: ghrepo.New("cli", "cli")}, + remotes: []*context.Remote{}, + input: "cli/cli", + output: SetDefaultOptions{Repo: ghrepo.New("cli", "cli")}, + }, + { + name: "invalid repo argument", + gitStubs: func(cs *run.CommandStubber) { + cs.Register(`git rev-parse --git-dir`, 0, ".git") + }, + remotes: []*context.Remote{}, + input: "some_invalid_format", + wantErr: true, + errMsg: `expected the "[HOST/]OWNER/REPO" format or a remote name, got "some_invalid_format"`, }, { - name: "invalid repo argument", - gitStubs: func(cs *run.CommandStubber) {}, - input: "some_invalid_format", - wantErr: true, - errMsg: `expected the "[HOST/]OWNER/REPO" format, got "some_invalid_format"`, + name: "repo argument is first remote name", + gitStubs: func(cs *run.CommandStubber) { + cs.Register(`git rev-parse --git-dir`, 0, ".git") + }, + remotes: []*context.Remote{ + { + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, + input: "origin", + output: SetDefaultOptions{Repo: ghrepo.New("OWNER", "REPO")}, + }, + { + name: "repo argument is arbitrary remote name", + gitStubs: func(cs *run.CommandStubber) { + cs.Register(`git rev-parse --git-dir`, 0, ".git") + }, + remotes: []*context.Remote{ + { + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("OWNER", "REPO"), + }, + { + Remote: &git.Remote{Name: "upstream"}, + Repo: ghrepo.New("OWNER2", "REPO2"), + }, + { + Remote: &git.Remote{Name: "other"}, + Repo: ghrepo.New("OWNER3", "REPO3"), + }, + }, + input: "upstream", + output: SetDefaultOptions{Repo: ghrepo.New("OWNER2", "REPO2")}, + }, + { + name: "repo argument is a remote name, but no such remote exists", + gitStubs: func(cs *run.CommandStubber) { + cs.Register(`git rev-parse --git-dir`, 0, ".git") + }, + remotes: []*context.Remote{ + { + Remote: &git.Remote{Name: "origin"}, + Repo: ghrepo.New("OWNER", "REPO"), + }, + }, + input: "upstream", + wantErr: true, + errMsg: `expected the "[HOST/]OWNER/REPO" format or a remote name, got "upstream"`, }, { name: "view flag", gitStubs: func(cs *run.CommandStubber) { cs.Register(`git rev-parse --git-dir`, 0, ".git") }, - input: "--view", - output: SetDefaultOptions{ViewMode: true}, + remotes: []*context.Remote{}, + input: "--view", + output: SetDefaultOptions{ViewMode: true}, }, { name: "unset flag", gitStubs: func(cs *run.CommandStubber) { cs.Register(`git rev-parse --git-dir`, 0, ".git") }, - input: "--unset", - output: SetDefaultOptions{UnsetMode: true}, + input: "--unset", + remotes: []*context.Remote{}, + output: SetDefaultOptions{UnsetMode: true}, }, { name: "run from non-git directory", gitStubs: func(cs *run.CommandStubber) { cs.Register(`git rev-parse --git-dir`, 128, "") }, + remotes: []*context.Remote{}, input: "", wantErr: true, errMsg: "must be run from inside a git repository", @@ -81,24 +141,26 @@ func TestNewCmdSetDefault(t *testing.T) { io.SetStdoutTTY(true) io.SetStdinTTY(true) io.SetStderrTTY(true) - f := &cmdutil.Factory{ - IOStreams: io, - GitClient: &git.Client{GitPath: "/fake/path/to/git"}, - } - - var gotOpts *SetDefaultOptions - cmd := NewCmdSetDefault(f, func(opts *SetDefaultOptions) error { - gotOpts = opts - return nil - }) - cmd.SetIn(&bytes.Buffer{}) - cmd.SetOut(&bytes.Buffer{}) - cmd.SetErr(&bytes.Buffer{}) t.Run(tt.name, func(t *testing.T) { argv, err := shlex.Split(tt.input) assert.NoError(t, err) + f := &cmdutil.Factory{ + IOStreams: io, + GitClient: &git.Client{GitPath: "/fake/path/to/git"}, + Remotes: func() (context.Remotes, error) { return tt.remotes, nil }, + } + + var gotOpts *SetDefaultOptions + cmd := NewCmdSetDefault(f, func(opts *SetDefaultOptions) error { + gotOpts = opts + return nil + }) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + cmd.SetArgs(argv) cs, teardown := run.Stub()