diff --git a/command.go b/command.go index d1a9d0575d..d1fc008e78 100644 --- a/command.go +++ b/command.go @@ -303,6 +303,9 @@ func (cmd *Command) appendFlag(fl Flag) { // VisiblePersistentFlags returns a slice of [LocalFlag] with Persistent=true and Hidden=false. func (cmd *Command) VisiblePersistentFlags() []Flag { + if cmd.isCompletionCommand { + return nil + } var flags []Flag for _, fl := range cmd.Root().Flags { pfl, ok := fl.(LocalFlag) diff --git a/command_run.go b/command_run.go index 269dd85f2e..87fa7e7fa5 100644 --- a/command_run.go +++ b/command_run.go @@ -141,12 +141,6 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context var rargs Args = &stringSliceArgs{v: osArgs} var args Args = &stringSliceArgs{rargs.Tail()} - if cmd.isCompletionCommand { - tracef("completion command detected, skipping pre-parse (cmd=%[1]q)", cmd.Name) - cmd.parsedArgs = args - return ctx, cmd.Action(ctx, cmd) - } - for _, f := range cmd.allFlags() { if err := f.PreParse(); err != nil { return ctx, err @@ -268,13 +262,12 @@ func (cmd *Command) run(ctx context.Context, osArgs []string) (_ context.Context subCmd = cmd.Command(name) if subCmd == nil { hasDefault := cmd.DefaultCommand != "" - isFlagName := slices.Contains(cmd.FlagNames(), name) if hasDefault { tracef("using default command=%[1]q (cmd=%[2]q)", cmd.DefaultCommand, cmd.Name) } - if isFlagName || hasDefault { + if hasDefault { argsWithDefault := cmd.argsWithDefaultCommand(cmd.parsedArgs) tracef("using default command args=%[1]q (cmd=%[2]q)", argsWithDefault, cmd.Name) subCmd = cmd.Command(argsWithDefault.First()) diff --git a/command_test.go b/command_test.go index 72fa6a6628..9d77c12f23 100644 --- a/command_test.go +++ b/command_test.go @@ -5683,6 +5683,53 @@ func TestEmptyPositionalArgs(t *testing.T) { } } +// Regression for #2234: an empty positional arg following a flag used to be +// dropped along with every arg after it. +func TestEmptyPositionalArgsAfterFlag(t *testing.T) { + testCases := []struct { + Name string + Args []string + ExpectedArgs []string + ExpectedFlag string + }{ + { + Name: "empty arg after equals-form flag", + Args: []string{"app", "-f=something", "", "arg2", "arg3"}, + ExpectedArgs: []string{"", "arg2", "arg3"}, + ExpectedFlag: "something", + }, + { + Name: "empty arg after space-form flag", + Args: []string{"app", "-f", "something", "", "arg2"}, + ExpectedArgs: []string{"", "arg2"}, + ExpectedFlag: "something", + }, + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + var args []string + var flagVal string + + cmd := &Command{ + Flags: []Flag{ + &StringFlag{Name: "f"}, + }, + Action: func(_ context.Context, cmd *Command) error { + args = cmd.Args().Slice() + flagVal = cmd.String("f") + return nil + }, + } + + err := cmd.Run(buildTestContext(t), tc.Args) + assert.NoError(t, err) + assert.Equal(t, tc.ExpectedArgs, args) + assert.Equal(t, tc.ExpectedFlag, flagVal) + }) + } +} + func TestFlagEqualsEmptyValue(t *testing.T) { t.Run("--flag= sets empty string", func(t *testing.T) { var val string @@ -5724,3 +5771,33 @@ func TestFlagEqualsEmptyValue(t *testing.T) { assert.Equal(t, []string{"positional"}, args) }) } + +// TestCommand_NoDefaultCmdArgMatchingFlag tests the argument set +// of a command which has no default command, and has a flag with +// the name of the next argument +func TestCommand_NoDefaultCmdArgMatchingFlag(t *testing.T) { + expectedArgs := stringSliceArgs{v: []string{"flag"}} + var actualArgs Args + cmd := &Command{ + Name: "rootCommand", + Flags: []Flag{ + &StringFlag{ + Name: "flag", + }, + }, + Commands: []*Command{ + { + Name: "subCommand", + }, + }, + Action: func(ctx context.Context, c *Command) error { + actualArgs = c.Args() + return nil + }, + } + // the last element - "flag" - is an argument sharing the same name as the flag + // "flag" of the rootCommand command + err := cmd.Run(buildTestContext(t), []string{"rootCommand", "--flag", "flagvalue", "flag"}) + require.NoError(t, err) + require.Equal(t, &expectedArgs, actualArgs) +} diff --git a/completion.go b/completion.go index de11edb406..167b0c0123 100644 --- a/completion.go +++ b/completion.go @@ -4,7 +4,6 @@ import ( "context" "embed" "fmt" - "sort" "strings" ) @@ -58,45 +57,36 @@ Output the script to path/to/autocomplete/$COMMAND.ps1 an run it. ` func buildCompletionCommand(appName string) *Command { - return &Command{ - Name: completionCommandName, - Hidden: true, - Usage: "Output shell completion script for bash, zsh, fish, or Powershell", - Description: strings.ReplaceAll(completionDescription, "$COMMAND", appName), - Action: func(ctx context.Context, cmd *Command) error { - return printShellCompletion(ctx, cmd, appName) - }, + cmd := &Command{ + Name: completionCommandName, + Hidden: true, + Usage: "Output shell completion script for bash, zsh, fish, or Powershell", + Description: strings.ReplaceAll(completionDescription, "$COMMAND", appName), isCompletionCommand: true, } -} - -func printShellCompletion(_ context.Context, cmd *Command, appName string) error { - var shells []string - for k := range shellCompletions { - shells = append(shells, k) - } - - sort.Strings(shells) - - if cmd.Args().Len() == 0 { - return Exit(fmt.Sprintf("no shell provided for completion command. available shells are %+v", shells), 1) - } - s := cmd.Args().First() - renderCompletion, ok := shellCompletions[s] - if !ok { - return Exit(fmt.Sprintf("unknown shell %s, available shells are %+v", s, shells), 1) + for shell, render := range shellCompletions { + cmd.Commands = append(cmd.Commands, buildShellCompletionSubcommand(shell, render, appName)) } - completionScript, err := renderCompletion(cmd, appName) - if err != nil { - return Exit(err, 1) - } + return cmd +} - _, err = cmd.Writer.Write([]byte(completionScript)) - if err != nil { - return Exit(err, 1) +func buildShellCompletionSubcommand(shell string, render renderCompletion, appName string) *Command { + return &Command{ + Name: shell, + Usage: fmt.Sprintf("Output %s completion script", shell), + isCompletionCommand: true, + Action: func(ctx context.Context, cmd *Command) error { + completionScript, err := render(cmd, appName) + if err != nil { + return Exit(err, 1) + } + _, err = cmd.Root().Writer.Write([]byte(completionScript)) + if err != nil { + return Exit(err, 1) + } + return nil + }, } - - return nil } diff --git a/completion_test.go b/completion_test.go index 889bfdcf59..3349701f53 100644 --- a/completion_test.go +++ b/completion_test.go @@ -11,6 +11,53 @@ import ( "github.com/stretchr/testify/require" ) +func TestCompletionHelp(t *testing.T) { + tests := []struct { + name string + args []string + }{ + { + name: "short help flag", + args: []string{"foo", completionCommandName, "-h"}, + }, + { + name: "long help flag", + args: []string{"foo", completionCommandName, "--help"}, + }, + { + name: "completion bash short help flag", + args: []string{"foo", completionCommandName, "bash", "-h"}, + }, + { + name: "completion bash long help flag", + args: []string{"foo", completionCommandName, "bash", "--help"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + out := &bytes.Buffer{} + + cmd := &Command{ + EnableShellCompletion: true, + Writer: out, + Flags: []Flag{ + &StringFlag{ + Name: "required-flag", + Required: true, + }, + }, + } + + r := require.New(t) + + r.NoError(cmd.Run(buildTestContext(t), test.args)) + r.Contains(out.String(), "USAGE") + r.NotContains(out.String(), "GLOBAL OPTIONS") + }) + } +} + func TestCompletionDisable(t *testing.T) { cmd := &Command{} @@ -19,8 +66,11 @@ func TestCompletionDisable(t *testing.T) { } func TestCompletionEnable(t *testing.T) { + out := &bytes.Buffer{} + cmd := &Command{ EnableShellCompletion: true, + Writer: out, Flags: []Flag{ &StringFlag{ Name: "goo", @@ -29,18 +79,23 @@ func TestCompletionEnable(t *testing.T) { }, } - err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName}) - assert.ErrorContains(t, err, "no shell provided") + r := require.New(t) + r.NoError(cmd.Run(buildTestContext(t), []string{"foo", completionCommandName})) + r.Contains(out.String(), "USAGE") } func TestCompletionEnableDiffCommandName(t *testing.T) { + out := &bytes.Buffer{} + cmd := &Command{ EnableShellCompletion: true, ShellCompletionCommandName: "junky", + Writer: out, } - err := cmd.Run(buildTestContext(t), []string{"foo", "junky"}) - assert.ErrorContains(t, err, "no shell provided") + r := require.New(t) + r.NoError(cmd.Run(buildTestContext(t), []string{"foo", "junky"})) + r.Contains(out.String(), "USAGE") } func TestCompletionShell(t *testing.T) { @@ -56,10 +111,7 @@ func TestCompletionShell(t *testing.T) { r := require.New(t) r.NoError(cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, k})) - r.Containsf( - k, out.String(), - "Expected output to contain shell name %[1]q", k, - ) + r.NotEmpty(out.String(), "Expected non-empty completion output for shell %q", k) }) } } @@ -255,17 +307,6 @@ func TestCompletionSubcommand(t *testing.T) { } } -type mockWriter struct { - err error -} - -func (mw *mockWriter) Write(p []byte) (int, error) { - if mw.err != nil { - return 0, mw.err - } - return len(p), nil -} - func TestCompletionInvalidShell(t *testing.T) { cmd := &Command{ EnableShellCompletion: true, @@ -273,7 +314,11 @@ func TestCompletionInvalidShell(t *testing.T) { unknownShellName := "junky-sheell" err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName}) - assert.ErrorContains(t, err, "unknown shell junky-sheell") + assert.ErrorContains(t, err, fmt.Sprintf("No help topic for '%s'", unknownShellName)) +} + +func TestCompletionShellRenderError(t *testing.T) { + unknownShellName := "junky-sheell" enableError := true shellCompletions[unknownShellName] = func(c *Command, appName string) (string, error) { @@ -286,16 +331,39 @@ func TestCompletionInvalidShell(t *testing.T) { delete(shellCompletions, unknownShellName) }() - err = cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName}) + cmd := &Command{ + EnableShellCompletion: true, + } + + err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName}) assert.ErrorContains(t, err, "cant do completion") +} - // now disable shell completion error - enableError = false - c := cmd.Command(completionCommandName) - assert.NotNil(t, c) - c.Writer = &mockWriter{ - err: fmt.Errorf("writer error"), +type mockWriter struct { + err error +} + +func (mw *mockWriter) Write(p []byte) (int, error) { + if mw.err != nil { + return 0, mw.err + } + return len(p), nil +} + +func TestCompletionShellWriteError(t *testing.T) { + shellName := "mock-shell" + shellCompletions[shellName] = func(c *Command, appName string) (string, error) { + return "something", nil } - err = cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, unknownShellName}) + defer func() { + delete(shellCompletions, shellName) + }() + + cmd := &Command{ + EnableShellCompletion: true, + Writer: &mockWriter{err: fmt.Errorf("writer error")}, + } + + err := cmd.Run(buildTestContext(t), []string{"foo", completionCommandName, shellName}) assert.ErrorContains(t, err, "writer error") } diff --git a/mkdocs-requirements.txt b/mkdocs-requirements.txt index d7f1b2c694..4e68d7d375 100644 --- a/mkdocs-requirements.txt +++ b/mkdocs-requirements.txt @@ -1,4 +1,4 @@ -mkdocs-git-revision-date-localized-plugin==1.5.1 +mkdocs-git-revision-date-localized-plugin==1.5.2 mkdocs-material==9.7.6 mkdocs==1.6.1 mkdocs-redirects==1.2.3