diff --git a/command.go b/command.go index 8b99e26..9cf25ea 100644 --- a/command.go +++ b/command.go @@ -59,6 +59,12 @@ type Command struct { Middleware MiddlewareFunc Handler HandlerFunc HelpHandler HandlerFunc + // CompletionHandler is called when the command is run in completion + // mode. If nil, only the default completion handler is used. + // + // Flag and option parsing is best-effort in this mode, so even if an Option + // is "required" it may not be set. + CompletionHandler CompletionHandlerFunc } // AddSubcommands adds the given subcommands, setting their @@ -193,15 +199,22 @@ type Invocation struct { ctx context.Context Command *Command parsedFlags *pflag.FlagSet - Args []string + + // Args is reduced into the remaining arguments after parsing flags + // during Run. + Args []string + // Environ is a list of environment variables. Use EnvsWithPrefix to parse // os.Environ. Environ Environ Stdout io.Writer Stderr io.Writer Stdin io.Reader - Logger slog.Logger - Net Net + + // Deprecated + Logger slog.Logger + // Deprecated + Net Net // testing signalNotifyContext func(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) @@ -282,6 +295,23 @@ func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { return fs2 } +func (inv *Invocation) CurWords() (prev string, cur string) { + switch len(inv.Args) { + // All the shells we support will supply at least one argument (empty string), + // but we don't want to panic. + case 0: + cur = "" + prev = "" + case 1: + cur = inv.Args[0] + prev = "" + default: + cur = inv.Args[len(inv.Args)-1] + prev = inv.Args[len(inv.Args)-2] + } + return +} + // run recursively executes the command and its children. // allArgs is wired through the stack so that global flags can be accepted // anywhere in the command invocation. @@ -378,8 +408,19 @@ func (inv *Invocation) run(state *runState) error { } } + // Outputted completions are not filtered based on the word under the cursor, as every shell we support does this already. + // We only look at the current word to figure out handler to run, or what directory to inspect. + if inv.IsCompletionMode() { + for _, e := range inv.complete() { + fmt.Fprintln(inv.Stdout, e) + } + return nil + } + + ignoreFlagParseErrors := inv.Command.RawArgs + // Flag parse errors are irrelevant for raw args commands. - if !inv.Command.RawArgs && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { + if !ignoreFlagParseErrors && state.flagParseErr != nil && !errors.Is(state.flagParseErr, pflag.ErrHelp) { return xerrors.Errorf( "parsing flags (%v) for %q: %w", state.allArgs, @@ -392,11 +433,16 @@ func (inv *Invocation) run(state *runState) error { var missing []string for _, opt := range inv.Command.Options { if opt.Required && opt.ValueSource == ValueSourceNone { - missing = append(missing, opt.Flag) + name := opt.Name + // use flag as a fallback if name is empty + if name == "" { + name = opt.Flag + } + missing = append(missing, name) } } // Don't error for missing flags if `--help` was supplied. - if len(missing) > 0 && !errors.Is(state.flagParseErr, pflag.ErrHelp) { + if len(missing) > 0 && !inv.IsCompletionMode() && !errors.Is(state.flagParseErr, pflag.ErrHelp) { return xerrors.Errorf("Missing values for the required flags: %s", strings.Join(missing, ", ")) } @@ -433,7 +479,7 @@ func (inv *Invocation) run(state *runState) error { if inv.Command.Handler == nil || errors.Is(state.flagParseErr, pflag.ErrHelp) { if inv.Command.HelpHandler == nil { - return defaultHelpFn()(inv) + return DefaultHelpFn()(inv) } return inv.Command.HelpHandler(inv) } @@ -553,6 +599,69 @@ func (inv *Invocation) with(fn func(*Invocation)) *Invocation { return &i2 } +func (inv *Invocation) complete() []string { + prev, cur := inv.CurWords() + + // If the current word is a flag + if strings.HasPrefix(cur, "--") { + flagParts := strings.Split(cur, "=") + flagName := flagParts[0][2:] + // If it's an equals flag + if len(flagParts) == 2 { + if out := inv.completeFlag(flagName); out != nil { + for i, o := range out { + out[i] = fmt.Sprintf("--%s=%s", flagName, o) + } + return out + } + } else if out := inv.Command.Options.ByFlag(flagName); out != nil { + // If the current word is a valid flag, auto-complete it so the + // shell moves the cursor + return []string{cur} + } + } + // If the previous word is a flag, then we're writing it's value + // and we should check it's handler + if strings.HasPrefix(prev, "--") { + word := prev[2:] + if out := inv.completeFlag(word); out != nil { + return out + } + } + // If the current word is the command, move the shell cursor + if inv.Command.Name() == cur { + return []string{inv.Command.Name()} + } + var completions []string + + if inv.Command.CompletionHandler != nil { + completions = append(completions, inv.Command.CompletionHandler(inv)...) + } + + completions = append(completions, DefaultCompletionHandler(inv)...) + + return completions +} + +func (inv *Invocation) completeFlag(word string) []string { + opt := inv.Command.Options.ByFlag(word) + if opt == nil { + return nil + } + if opt.CompletionHandler != nil { + return opt.CompletionHandler(inv) + } + enum, ok := opt.Value.(*Enum) + if ok { + return enum.Choices + } + enumArr, ok := opt.Value.(*EnumArray) + if ok { + return enumArr.Choices + } + return nil +} + // MiddlewareFunc returns the next handler in the chain, // or nil if there are no more. type MiddlewareFunc func(next HandlerFunc) HandlerFunc @@ -637,3 +746,5 @@ func RequireRangeArgs(start, end int) MiddlewareFunc { // HandlerFunc handles an Invocation of a command. type HandlerFunc func(i *Invocation) error + +type CompletionHandlerFunc func(i *Invocation) []string diff --git a/command_test.go b/command_test.go index f6a20a2..e4fa951 100644 --- a/command_test.go +++ b/command_test.go @@ -12,6 +12,7 @@ import ( "golang.org/x/xerrors" serpent "github.com/coder/serpent" + "github.com/coder/serpent/completion" ) // ioBufs is the standard input, output, and error for a command. @@ -30,100 +31,153 @@ func fakeIO(i *serpent.Invocation) *ioBufs { return &b } -func TestCommand(t *testing.T) { - t.Parallel() - - cmd := func() *serpent.Command { - var ( - verbose bool - lower bool - prefix string - reqBool bool - reqStr string - ) - return &serpent.Command{ - Use: "root [subcommand]", - Options: serpent.OptionSet{ - serpent.Option{ - Name: "verbose", - Flag: "verbose", - Value: serpent.BoolOf(&verbose), - }, - serpent.Option{ - Name: "prefix", - Flag: "prefix", - Value: serpent.StringOf(&prefix), - }, +func sampleCommand(t *testing.T) *serpent.Command { + t.Helper() + var ( + verbose bool + lower bool + prefix string + reqBool bool + reqStr string + reqArr []string + reqEnumArr []string + fileArr []string + enumStr string + ) + enumChoices := []string{"foo", "bar", "qux"} + return &serpent.Command{ + Use: "root [subcommand]", + Options: serpent.OptionSet{ + serpent.Option{ + Name: "verbose", + Flag: "verbose", + Value: serpent.BoolOf(&verbose), }, - Children: []*serpent.Command{ - { - Use: "required-flag --req-bool=true --req-string=foo", - Short: "Example with required flags", - Options: serpent.OptionSet{ - serpent.Option{ - Name: "req-bool", - Flag: "req-bool", - Value: serpent.BoolOf(&reqBool), - Required: true, - }, - serpent.Option{ - Name: "req-string", - Flag: "req-string", - Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error { - ok := strings.Contains(value.String(), " ") - if !ok { - return xerrors.Errorf("string must contain a space") - } - return nil - }), - Required: true, - }, + serpent.Option{ + Name: "prefix", + Flag: "prefix", + Value: serpent.StringOf(&prefix), + }, + }, + Children: []*serpent.Command{ + { + Use: "required-flag --req-bool=true --req-string=foo", + Short: "Example with required flags", + Options: serpent.OptionSet{ + serpent.Option{ + Name: "req-bool", + Flag: "req-bool", + FlagShorthand: "b", + Value: serpent.BoolOf(&reqBool), + Required: true, }, - HelpHandler: func(i *serpent.Invocation) error { - _, _ = i.Stdout.Write([]byte("help text.png")) - return nil + serpent.Option{ + Name: "req-string", + Flag: "req-string", + FlagShorthand: "s", + Value: serpent.Validate(serpent.StringOf(&reqStr), func(value *serpent.String) error { + ok := strings.Contains(value.String(), " ") + if !ok { + return xerrors.Errorf("string must contain a space") + } + return nil + }), + Required: true, }, - Handler: func(i *serpent.Invocation) error { - _, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool))) - return nil + serpent.Option{ + Name: "req-enum", + Flag: "req-enum", + Value: serpent.EnumOf(&enumStr, enumChoices...), + }, + serpent.Option{ + Name: "req-array", + Flag: "req-array", + FlagShorthand: "a", + Value: serpent.StringArrayOf(&reqArr), + }, + serpent.Option{ + Name: "req-enum-array", + Flag: "req-enum-array", + Value: serpent.EnumArrayOf(&reqEnumArr, enumChoices...), }, }, - { - Use: "toupper [word]", - Short: "Converts a word to upper case", - Middleware: serpent.Chain( - serpent.RequireNArgs(1), - ), - Aliases: []string{"up"}, - Options: serpent.OptionSet{ - serpent.Option{ - Name: "lower", - Flag: "lower", - Value: serpent.BoolOf(&lower), - }, + HelpHandler: func(i *serpent.Invocation) error { + _, _ = i.Stdout.Write([]byte("help text.png")) + return nil + }, + Handler: func(i *serpent.Invocation) error { + _, _ = i.Stdout.Write([]byte(fmt.Sprintf("%s-%t", reqStr, reqBool))) + return nil + }, + }, + { + Use: "toupper [word]", + Short: "Converts a word to upper case", + Middleware: serpent.Chain( + serpent.RequireNArgs(1), + ), + Aliases: []string{"up"}, + Options: serpent.OptionSet{ + serpent.Option{ + Name: "lower", + Flag: "lower", + Value: serpent.BoolOf(&lower), }, - Handler: func(i *serpent.Invocation) error { - _, _ = i.Stdout.Write([]byte(prefix)) - w := i.Args[0] - if lower { - w = strings.ToLower(w) - } else { - w = strings.ToUpper(w) - } - _, _ = i.Stdout.Write( - []byte( - w, - ), - ) - if verbose { - _, _ = i.Stdout.Write([]byte("!!!")) - } - return nil + }, + Handler: func(i *serpent.Invocation) error { + _, _ = i.Stdout.Write([]byte(prefix)) + w := i.Args[0] + if lower { + w = strings.ToLower(w) + } else { + w = strings.ToUpper(w) + } + _, _ = i.Stdout.Write( + []byte( + w, + ), + ) + if verbose { + _, _ = i.Stdout.Write([]byte("!!!")) + } + return nil + }, + }, + { + Use: "file ", + Handler: func(inv *serpent.Invocation) error { + return nil + }, + CompletionHandler: completion.FileHandler(func(info os.FileInfo) bool { + return true + }), + Middleware: serpent.RequireNArgs(1), + }, + { + Use: "altfile", + Handler: func(inv *serpent.Invocation) error { + return nil + }, + Options: serpent.OptionSet{ + { + Name: "extra", + Flag: "extra", + Description: "Extra files.", + Value: serpent.StringArrayOf(&fileArr), }, }, + CompletionHandler: func(i *serpent.Invocation) []string { + return []string{"doesntexist.go"} + }, }, - } + }, } +} + +func TestCommand(t *testing.T) { + t.Parallel() + + cmd := func() *serpent.Command { return sampleCommand(t) } t.Run("SimpleOK", func(t *testing.T) { t.Parallel() @@ -498,11 +552,22 @@ func TestCommand_RootRaw(t *testing.T) { func TestCommand_HyphenHyphen(t *testing.T) { t.Parallel() + var verbose bool cmd := &serpent.Command{ Handler: (func(i *serpent.Invocation) error { _, _ = i.Stdout.Write([]byte(strings.Join(i.Args, " "))) + if verbose { + return xerrors.New("verbose should not be true because flag after --") + } return nil }), + Options: serpent.OptionSet{ + { + Name: "verbose", + Flag: "verbose", + Value: serpent.BoolOf(&verbose), + }, + }, } inv := cmd.Invoke("--", "--verbose", "--friendly") diff --git a/completion.go b/completion.go new file mode 100644 index 0000000..d82a06e --- /dev/null +++ b/completion.go @@ -0,0 +1,40 @@ +package serpent + +import ( + "strings" + + "github.com/spf13/pflag" +) + +// CompletionModeEnv is a special environment variable that is +// set when the command is being run in completion mode. +const CompletionModeEnv = "COMPLETION_MODE" + +// IsCompletionMode returns true if the command is being run in completion mode. +func (inv *Invocation) IsCompletionMode() bool { + _, ok := inv.Environ.Lookup(CompletionModeEnv) + return ok +} + +// DefaultCompletionHandler is a handler that prints all the subcommands, or +// all the options that haven't been exhaustively set, if the current word +// starts with a dash. +func DefaultCompletionHandler(inv *Invocation) []string { + _, cur := inv.CurWords() + var allResps []string + if strings.HasPrefix(cur, "-") { + for _, opt := range inv.Command.Options { + _, isSlice := opt.Value.(pflag.SliceValue) + if opt.ValueSource == ValueSourceNone || + opt.ValueSource == ValueSourceDefault || + isSlice { + allResps = append(allResps, "--"+opt.Flag) + } + } + return allResps + } + for _, cmd := range inv.Command.Children { + allResps = append(allResps, cmd.Name()) + } + return allResps +} diff --git a/completion/README.md b/completion/README.md new file mode 100644 index 0000000..d7021ed --- /dev/null +++ b/completion/README.md @@ -0,0 +1,11 @@ +# completion + +The `completion` package extends `serpent` to allow applications to generate rich auto-completions. + + +## Protocol + +The completion scripts call out to the serpent command to generate +completions. The convention is to pass the exact args and flags (or +cmdline) of the in-progress command with a `COMPLETION_MODE=1` environment variable. That environment variable lets the command know to generate completions instead of running the command. +By default, completions will be generated based on available flags and subcommands. Additional completions can be added by supplying a `CompletionHandlerFunc` on an Option or Command. \ No newline at end of file diff --git a/completion/all.go b/completion/all.go new file mode 100644 index 0000000..ca6e2cf --- /dev/null +++ b/completion/all.go @@ -0,0 +1,198 @@ +package completion + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/fs" + "os" + "os/user" + "path/filepath" + "runtime" + "strings" + "text/template" + + "github.com/coder/serpent" + + "github.com/natefinch/atomic" +) + +const ( + completionStartTemplate = `# ============ BEGIN {{.Name}} COMPLETION ============` + completionEndTemplate = `# ============ END {{.Name}} COMPLETION ==============` +) + +type Shell interface { + Name() string + InstallPath() (string, error) + WriteCompletion(io.Writer) error + ProgramName() string +} + +const ( + ShellBash string = "bash" + ShellFish string = "fish" + ShellZsh string = "zsh" + ShellPowershell string = "powershell" +) + +func ShellByName(shell, programName string) (Shell, error) { + switch shell { + case ShellBash: + return Bash(runtime.GOOS, programName), nil + case ShellFish: + return Fish(runtime.GOOS, programName), nil + case ShellZsh: + return Zsh(runtime.GOOS, programName), nil + case ShellPowershell: + return Powershell(runtime.GOOS, programName), nil + default: + return nil, fmt.Errorf("unsupported shell %q", shell) + } +} + +func ShellOptions(choice *string) *serpent.Enum { + return serpent.EnumOf(choice, ShellBash, ShellFish, ShellZsh, ShellPowershell) +} + +func DetectUserShell(programName string) (Shell, error) { + // Attempt to get the SHELL environment variable first + if shell := os.Getenv("SHELL"); shell != "" { + return ShellByName(filepath.Base(shell), "") + } + + // Fallback: Look up the current user and parse /etc/passwd + currentUser, err := user.Current() + if err != nil { + return nil, err + } + + // Open and parse /etc/passwd + passwdFile, err := os.ReadFile("/etc/passwd") + if err != nil { + return nil, err + } + + lines := strings.Split(string(passwdFile), "\n") + for _, line := range lines { + if strings.HasPrefix(line, currentUser.Username+":") { + parts := strings.Split(line, ":") + if len(parts) > 6 { + return ShellByName(filepath.Base(parts[6]), programName) // The shell is typically the 7th field + } + } + } + + return nil, fmt.Errorf("default shell not found") +} + +func writeConfig( + w io.Writer, + cfgTemplate string, + programName string, +) error { + tmpl, err := template.New("script").Parse(cfgTemplate) + if err != nil { + return fmt.Errorf("parse template: %w", err) + } + + err = tmpl.Execute( + w, + map[string]string{ + "Name": programName, + }, + ) + if err != nil { + return fmt.Errorf("execute template: %w", err) + } + + return nil +} + +func InstallShellCompletion(shell Shell) error { + path, err := shell.InstallPath() + if err != nil { + return fmt.Errorf("get install path: %w", err) + } + var headerBuf bytes.Buffer + err = writeConfig(&headerBuf, completionStartTemplate, shell.ProgramName()) + if err != nil { + return fmt.Errorf("generate header: %w", err) + } + + var footerBytes bytes.Buffer + err = writeConfig(&footerBytes, completionEndTemplate, shell.ProgramName()) + if err != nil { + return fmt.Errorf("generate footer: %w", err) + } + + err = os.MkdirAll(filepath.Dir(path), 0o755) + if err != nil { + return fmt.Errorf("create directories: %w", err) + } + + f, err := os.ReadFile(path) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("read ssh config failed: %w", err) + } + + before, after, err := templateConfigSplit(headerBuf.Bytes(), footerBytes.Bytes(), f) + if err != nil { + return err + } + + outBuf := bytes.Buffer{} + _, _ = outBuf.Write(before) + if len(before) > 0 { + _, _ = outBuf.Write([]byte("\n")) + } + _, _ = outBuf.Write(headerBuf.Bytes()) + err = shell.WriteCompletion(&outBuf) + if err != nil { + return fmt.Errorf("generate completion: %w", err) + } + _, _ = outBuf.Write(footerBytes.Bytes()) + _, _ = outBuf.Write([]byte("\n")) + _, _ = outBuf.Write(after) + + err = atomic.WriteFile(path, &outBuf) + if err != nil { + return fmt.Errorf("write completion: %w", err) + } + + return nil +} + +func templateConfigSplit(header, footer, data []byte) (before, after []byte, err error) { + startCount := bytes.Count(data, header) + endCount := bytes.Count(data, footer) + if startCount > 1 || endCount > 1 { + return nil, nil, fmt.Errorf("Malformed config file: multiple config sections") + } + + startIndex := bytes.Index(data, header) + endIndex := bytes.Index(data, footer) + if startIndex == -1 && endIndex != -1 { + return data, nil, fmt.Errorf("Malformed config file: missing completion header") + } + if startIndex != -1 && endIndex == -1 { + return data, nil, fmt.Errorf("Malformed config file: missing completion footer") + } + if startIndex != -1 && endIndex != -1 { + if startIndex > endIndex { + return data, nil, fmt.Errorf("Malformed config file: completion header after footer") + } + // Include leading and trailing newline, if present + start := startIndex + if start > 0 { + start-- + } + end := endIndex + len(footer) + if end < len(data) { + end++ + } + return data[:start], data[end:], nil + } + return data, nil, nil +} diff --git a/completion/bash.go b/completion/bash.go new file mode 100644 index 0000000..14282eb --- /dev/null +++ b/completion/bash.go @@ -0,0 +1,62 @@ +package completion + +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type bash struct { + goos string + programName string +} + +var _ Shell = &bash{} + +func Bash(goos string, programName string) Shell { + return &bash{goos: goos, programName: programName} +} + +func (b *bash) Name() string { + return "bash" +} + +func (b *bash) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + if b.goos == "darwin" { + return filepath.Join(homeDir, ".bash_profile"), nil + } + return filepath.Join(homeDir, ".bashrc"), nil +} + +func (b *bash) WriteCompletion(w io.Writer) error { + return writeConfig(w, bashCompletionTemplate, b.programName) +} + +func (b *bash) ProgramName() string { + return b.programName +} + +const bashCompletionTemplate = ` +_generate_{{.Name}}_completions() { + local args=("${COMP_WORDS[@]:1:COMP_CWORD}") + + declare -a output + mapfile -t output < <(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}") + + declare -a completions + mapfile -t completions < <( compgen -W "$(printf '%q ' "${output[@]}")" -- "$2" ) + + local comp + COMPREPLY=() + for comp in "${completions[@]}"; do + COMPREPLY+=("$(printf "%q" "$comp")") + done +} +# Setup Bash to use the function for completions for '{{.Name}}' +complete -F _generate_{{.Name}}_completions {{.Name}} +` diff --git a/completion/fish.go b/completion/fish.go new file mode 100644 index 0000000..7e5a21e --- /dev/null +++ b/completion/fish.go @@ -0,0 +1,51 @@ +package completion + +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type fish struct { + goos string + programName string +} + +var _ Shell = &fish{} + +func Fish(goos string, programName string) Shell { + return &fish{goos: goos, programName: programName} +} + +func (f *fish) Name() string { + return "fish" +} + +func (f *fish) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".config/fish/completions/", f.programName+".fish"), nil +} + +func (f *fish) WriteCompletion(w io.Writer) error { + return writeConfig(w, fishCompletionTemplate, f.programName) +} + +func (f *fish) ProgramName() string { + return f.programName +} + +const fishCompletionTemplate = ` +function _{{.Name}}_completions + # Capture the full command line as an array + set -l args (commandline -opc) + set -l current (commandline -ct) + COMPLETION_MODE=1 $args $current +end + +# Setup Fish to use the function for completions for '{{.Name}}' +complete -c {{.Name}} -f -a '(_{{.Name}}_completions)' +` diff --git a/completion/handlers.go b/completion/handlers.go new file mode 100644 index 0000000..848bb06 --- /dev/null +++ b/completion/handlers.go @@ -0,0 +1,55 @@ +package completion + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/coder/serpent" +) + +// FileHandler returns a handler that completes file names, using the +// given filter func, which may be nil. +func FileHandler(filter func(info os.FileInfo) bool) serpent.CompletionHandlerFunc { + return func(inv *serpent.Invocation) []string { + var out []string + _, word := inv.CurWords() + + dir, _ := filepath.Split(word) + if dir == "" { + dir = "." + } + f, err := os.Open(dir) + if err != nil { + return out + } + defer f.Close() + if dir == "." { + dir = "" + } + + infos, err := f.Readdir(0) + if err != nil { + return out + } + + for _, info := range infos { + if filter != nil && !filter(info) { + continue + } + + var cur string + if info.IsDir() { + cur = fmt.Sprintf("%s%s%c", dir, info.Name(), os.PathSeparator) + } else { + cur = fmt.Sprintf("%s%s", dir, info.Name()) + } + + if strings.HasPrefix(cur, word) { + out = append(out, cur) + } + } + return out + } +} diff --git a/completion/powershell.go b/completion/powershell.go new file mode 100644 index 0000000..cc083bd --- /dev/null +++ b/completion/powershell.go @@ -0,0 +1,86 @@ +package completion + +import ( + "io" + "os/exec" + "strings" +) + +type powershell struct { + goos string + programName string +} + +var _ Shell = &powershell{} + +func (p *powershell) Name() string { + return "powershell" +} + +func Powershell(goos string, programName string) Shell { + return &powershell{goos: goos, programName: programName} +} + +func (p *powershell) InstallPath() (string, error) { + var ( + path []byte + err error + ) + cmd := "$PROFILE.CurrentUserAllHosts" + if p.goos == "windows" { + path, err = exec.Command("powershell", cmd).CombinedOutput() + } else { + path, err = exec.Command("pwsh", "-Command", cmd).CombinedOutput() + } + if err != nil { + return "", err + } + return strings.TrimSpace(string(path)), nil +} + +func (p *powershell) WriteCompletion(w io.Writer) error { + return writeConfig(w, pshCompletionTemplate, p.programName) +} + +func (p *powershell) ProgramName() string { + return p.programName +} + +const pshCompletionTemplate = ` +# Escaping output sourced from: +# https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L47 +filter _{{.Name}}_escapeStringWithSpecialChars { +` + " $_ -replace '\\s|#|@|\\$|;|,|''|\\{|\\}|\\(|\\)|\"|`|\\||<|>|&','`$&'" + ` +} + +$_{{.Name}}_completions = { + param( + $wordToComplete, + $commandAst, + $cursorPosition + ) + # Legacy space handling sourced from: + # https://github.com/spf13/cobra/blob/e94f6d0dd9a5e5738dca6bce03c4b1207ffbc0ec/powershell_completions.go#L107 + if ($PSVersionTable.PsVersion -lt [version]'7.2.0' -or + ($PSVersionTable.PsVersion -lt [version]'7.3.0' -and -not [ExperimentalFeature]::IsEnabled("PSNativeCommandArgumentPassing")) -or + (($PSVersionTable.PsVersion -ge [version]'7.3.0' -or [ExperimentalFeature]::IsEnabled("PSNativeCommandArgumentPassing")) -and + $PSNativeCommandArgumentPassing -eq 'Legacy')) { + $Space =` + "' `\"`\"'" + ` + } else { + $Space = ' ""' + } + $Command = $commandAst.ToString().Substring(0, $cursorPosition - 1) + if ($wordToComplete -ne "" ) { + $wordToComplete = $Command.Split(" ")[-1] + } else { + $Command = $Command + $Space + } + # Get completions by calling the command with the COMPLETION_MODE environment variable set to 1 + $env:COMPLETION_MODE = 1 + Invoke-Expression $Command | Where-Object { $_ -like "$wordToComplete*" } | ForEach-Object { + "$_" | _{{.Name}}_escapeStringWithSpecialChars + } + $env:COMPLETION_MODE = '' +} +Register-ArgumentCompleter -CommandName {{.Name}} -ScriptBlock $_{{.Name}}_completions +` diff --git a/completion/zsh.go b/completion/zsh.go new file mode 100644 index 0000000..b2793b0 --- /dev/null +++ b/completion/zsh.go @@ -0,0 +1,49 @@ +package completion + +import ( + "io" + "path/filepath" + + home "github.com/mitchellh/go-homedir" +) + +type zsh struct { + goos string + programName string +} + +var _ Shell = &zsh{} + +func Zsh(goos string, programName string) Shell { + return &zsh{goos: goos, programName: programName} +} + +func (z *zsh) Name() string { + return "zsh" +} + +func (z *zsh) InstallPath() (string, error) { + homeDir, err := home.Dir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, ".zshrc"), nil +} + +func (z *zsh) WriteCompletion(w io.Writer) error { + return writeConfig(w, zshCompletionTemplate, z.programName) +} + +func (z *zsh) ProgramName() string { + return z.programName +} + +const zshCompletionTemplate = ` +_{{.Name}}_completions() { + local -a args completions + args=("${words[@]:1:$#words}") + completions=(${(f)"$(COMPLETION_MODE=1 "{{.Name}}" "${args[@]}")"}) + compadd -a completions +} +compdef _{{.Name}}_completions {{.Name}} +` diff --git a/completion_test.go b/completion_test.go new file mode 100644 index 0000000..43d6e1b --- /dev/null +++ b/completion_test.go @@ -0,0 +1,341 @@ +package serpent_test + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "testing" + + serpent "github.com/coder/serpent" + "github.com/coder/serpent/completion" + "github.com/stretchr/testify/require" +) + +func TestCompletion(t *testing.T) { + t.Parallel() + + cmd := func() *serpent.Command { return sampleCommand(t) } + + t.Run("SubcommandList", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n", io.Stdout.String()) + }) + + t.Run("SubcommandNoPartial", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("f") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "altfile\nfile\nrequired-flag\ntoupper\n", io.Stdout.String()) + }) + + t.Run("SubcommandComplete", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "required-flag\n", io.Stdout.String()) + }) + + t.Run("ListFlags", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "-") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-array\n--req-bool\n--req-enum\n--req-enum-array\n--req-string\n", io.Stdout.String()) + }) + + t.Run("ListFlagsAfterArg", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("altfile", "-") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "doesntexist.go\n--extra\n", io.Stdout.String()) + }) + + t.Run("FlagExhaustive", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-bool", "--req-string", "foo bar", "--req-array", "asdf", "--req-array", "qwerty", "-") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) + }) + + t.Run("FlagShorthand", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "-b", "-s", "foo bar", "-a", "asdf", "-") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-array\n--req-enum\n--req-enum-array\n", io.Stdout.String()) + }) + + t.Run("NoOptDefValueFlag", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("--verbose", "-") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--prefix\n", io.Stdout.String()) + }) + + t.Run("EnumOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) + }) + + t.Run("EnumEqualsOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum", "--req-enum=") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) + }) + + t.Run("EnumEqualsBeginQuotesOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum", "--req-enum=\"") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum=foo\n--req-enum=bar\n--req-enum=qux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array", "") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "foo\nbar\nqux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayEqualsOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array=") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) + }) + + t.Run("EnumArrayEqualsBeginQuotesOK", func(t *testing.T) { + t.Parallel() + i := cmd().Invoke("required-flag", "--req-enum-array=\"") + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "--req-enum-array=foo\n--req-enum-array=bar\n--req-enum-array=qux\n", io.Stdout.String()) + }) + +} + +func TestFileCompletion(t *testing.T) { + t.Parallel() + + cmd := func() *serpent.Command { return sampleCommand(t) } + + t.Run("DirOK", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + i := cmd().Invoke("file", tempDir) + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%s%c\n", tempDir, os.PathSeparator), io.Stdout.String()) + }) + + t.Run("EmptyDirOK", func(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + string(os.PathSeparator) + i := cmd().Invoke("file", tempDir) + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + require.Equal(t, "", io.Stdout.String()) + }) + + cases := []struct { + name string + realPath string + paths []string + }{ + { + name: "CurDirOK", + realPath: ".", + paths: []string{"", "./", "././"}, + }, + { + name: "PrevDirOK", + realPath: "..", + paths: []string{"../", ".././"}, + }, + { + name: "RootOK", + realPath: "/", + paths: []string{"/", "/././"}, + }, + } + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + for _, path := range tc.paths { + i := cmd().Invoke("file", path) + i.Environ.Set(serpent.CompletionModeEnv, "1") + io := fakeIO(i) + err := i.Run() + require.NoError(t, err) + output := strings.Split(io.Stdout.String(), "\n") + output = output[:len(output)-1] + for _, str := range output { + if strings.HasSuffix(str, string(os.PathSeparator)) { + require.DirExists(t, str) + } else { + require.FileExists(t, str) + } + } + files, err := os.ReadDir(tc.realPath) + require.NoError(t, err) + require.Equal(t, len(files), len(output)) + } + }) + } +} + +func TestCompletionInstall(t *testing.T) { + t.Parallel() + + t.Run("InstallingNew", func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + shell := &fakeShell{baseInstallDir: dir, programName: "fake"} + + err := completion.InstallShellCompletion(shell) + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, "# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n", string(contents)) + }) + + cases := []struct { + name string + input []byte + expected []byte + errMsg string + }{ + { + name: "InstallingAppend", + input: []byte("FAKE_SCRIPT"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + }, + { + name: "InstallReplaceBeginning", + input: []byte("# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + expected: []byte("# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + }, + { + name: "InstallReplaceMiddle", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\nFAKE_SCRIPT\n"), + }, + { + name: "InstallReplaceEnd", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + expected: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nFAKE_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + }, + { + name: "InstallNoFooter", + input: []byte("FAKE_SCRIPT\n# ============ BEGIN fake COMPLETION ============\nOLD_COMPLETION\n"), + errMsg: "missing completion footer", + }, + { + name: "InstallNoHeader", + input: []byte("OLD_COMPLETION\n# ============ END fake COMPLETION ==============\n"), + errMsg: "missing completion header", + }, + { + name: "InstallBadOrder", + input: []byte("# ============ END fake COMPLETION ==============\nFAKE_COMPLETION\n# ============ BEGIN fake COMPLETION =============="), + errMsg: "header after footer", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "fake.sh") + err := os.WriteFile(path, tc.input, 0o644) + require.NoError(t, err) + + shell := &fakeShell{baseInstallDir: dir, programName: "fake"} + err = completion.InstallShellCompletion(shell) + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } else { + require.NoError(t, err) + contents, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, tc.expected, contents) + } + }) + } +} + +type fakeShell struct { + baseInstallDir string + programName string +} + +func (f *fakeShell) ProgramName() string { + return f.programName +} + +var _ completion.Shell = &fakeShell{} + +func (f *fakeShell) InstallPath() (string, error) { + return filepath.Join(f.baseInstallDir, "fake.sh"), nil +} + +func (f *fakeShell) Name() string { + return "Fake" +} + +func (f *fakeShell) WriteCompletion(w io.Writer) error { + _, err := w.Write([]byte("\nFAKE_COMPLETION\n")) + return err +} diff --git a/example/completetest/main.go b/example/completetest/main.go new file mode 100644 index 0000000..add6d5c --- /dev/null +++ b/example/completetest/main.go @@ -0,0 +1,130 @@ +package main + +import ( + "fmt" + "os" + "strings" + + "github.com/coder/serpent" + "github.com/coder/serpent/completion" +) + +// installCommand returns a serpent command that helps +// a user configure their shell to use serpent's completion. +func installCommand() *serpent.Command { + var shell string + return &serpent.Command{ + Use: "completion [--shell ]", + Short: "Generate completion scripts for the given shell.", + Handler: func(inv *serpent.Invocation) error { + defaultShell, err := completion.DetectUserShell(inv.Command.Parent.Name()) + if err != nil { + return fmt.Errorf("Could not detect user shell, please specify a shell using `--shell`") + } + return defaultShell.WriteCompletion(inv.Stdout) + }, + Options: serpent.OptionSet{ + { + Flag: "shell", + FlagShorthand: "s", + Description: "The shell to generate a completion script for.", + Value: completion.ShellOptions(&shell), + }, + }, + } +} + +func main() { + var ( + print bool + upper bool + fileType string + fileArr []string + types []string + ) + cmd := serpent.Command{ + Use: "completetest ", + Short: "Prints the given text to the console.", + Options: serpent.OptionSet{ + { + Name: "different", + Value: serpent.BoolOf(&upper), + Flag: "different", + Description: "Do the command differently.", + }, + }, + Handler: func(inv *serpent.Invocation) error { + if len(inv.Args) == 0 { + inv.Stderr.Write([]byte("error: missing text\n")) + os.Exit(1) + } + + text := inv.Args[0] + if upper { + text = strings.ToUpper(text) + } + + inv.Stdout.Write([]byte(text)) + return nil + }, + Children: []*serpent.Command{ + { + Use: "sub", + Short: "A subcommand", + Handler: func(inv *serpent.Invocation) error { + inv.Stdout.Write([]byte("subcommand")) + return nil + }, + Options: serpent.OptionSet{ + { + Name: "upper", + Value: serpent.BoolOf(&upper), + Flag: "upper", + Description: "Prints the text in upper case.", + }, + }, + }, + { + Use: "file ", + Handler: func(inv *serpent.Invocation) error { + return nil + }, + Options: serpent.OptionSet{ + { + Name: "print", + Value: serpent.BoolOf(&print), + Flag: "print", + Description: "Print the file.", + }, + { + Name: "type", + Value: serpent.EnumOf(&fileType, "binary", "text"), + Flag: "type", + Description: "The type of file.", + }, + { + Name: "extra", + Flag: "extra", + Description: "Extra files.", + Value: serpent.StringArrayOf(&fileArr), + }, + { + Name: "types", + Flag: "types", + Value: serpent.EnumArrayOf(&types, "binary", "text"), + }, + }, + CompletionHandler: completion.FileHandler(nil), + Middleware: serpent.RequireNArgs(1), + }, + installCommand(), + }, + } + + inv := cmd.Invoke().WithOS() + + err := inv.Run() + if err != nil { + panic(err) + } +} diff --git a/go.mod b/go.mod index 8bd432e..1c2880c 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,10 @@ require ( cdr.dev/slog v1.6.2-0.20240126064726-20367d4aede6 github.com/coder/pretty v0.0.0-20230908205945-e89ba86370e0 github.com/hashicorp/go-multierror v1.1.1 + github.com/mitchellh/go-homedir v1.1.0 github.com/mitchellh/go-wordwrap v1.0.1 github.com/muesli/termenv v0.15.2 + github.com/natefinch/atomic v1.0.1 github.com/pion/udp v0.1.4 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 63d175d..a1106fc 100644 --- a/go.sum +++ b/go.sum @@ -46,12 +46,16 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= +github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/transport/v2 v2.0.0 h1:bsMYyqHCbkvHwj+eNCFBuxtlKndKfyGI2vaQmM3fIE4= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= diff --git a/help.go b/help.go index 3dc49a4..2ec4201 100644 --- a/help.go +++ b/help.go @@ -95,6 +95,8 @@ var defaultHelpTemplate = func() *template.Template { switch v := opt.Value.(type) { case *Enum: return strings.Join(v.Choices, "|") + case *EnumArray: + return fmt.Sprintf("[%s]", strings.Join(v.Choices, "|")) default: return v.Type() } @@ -320,9 +322,17 @@ func (lm *newlineLimiter) Write(p []byte) (int, error) { var usageWantsArgRe = regexp.MustCompile(`<.*>`) -// defaultHelpFn returns a function that generates usage (help) +type UnknownSubcommandError struct { + Args []string +} + +func (e *UnknownSubcommandError) Error() string { + return fmt.Sprintf("unknown subcommand %q", strings.Join(e.Args, " ")) +} + +// DefaultHelpFn returns a function that generates usage (help) // output for a given command. -func defaultHelpFn() HandlerFunc { +func DefaultHelpFn() HandlerFunc { return func(inv *Invocation) error { // We use stdout for help and not stderr since there's no straightforward // way to distinguish between a user error and a help request. @@ -350,7 +360,7 @@ func defaultHelpFn() HandlerFunc { if len(inv.Args) > 0 { // Return an error so that exit status is non-zero when // a subcommand is not found. - return fmt.Errorf("error: unknown subcommand %q", strings.Join(inv.Args, " ")) + return &UnknownSubcommandError{Args: inv.Args} } return nil } diff --git a/option.go b/option.go index 5545d07..2780fc6 100644 --- a/option.go +++ b/option.go @@ -65,6 +65,8 @@ type Option struct { Hidden bool `json:"hidden,omitempty"` ValueSource ValueSource `json:"value_source,omitempty"` + + CompletionHandler CompletionHandlerFunc `json:"-"` } // optionNoMethods is just a wrapper around Option so we can defer to the @@ -335,10 +337,22 @@ func (optSet *OptionSet) SetDefaults() error { // ByName returns the Option with the given name, or nil if no such option // exists. -func (optSet *OptionSet) ByName(name string) *Option { - for i := range *optSet { - opt := &(*optSet)[i] - if opt.Name == name { +func (optSet OptionSet) ByName(name string) *Option { + for i := range optSet { + if optSet[i].Name == name { + return &optSet[i] + } + } + return nil +} + +func (optSet OptionSet) ByFlag(flag string) *Option { + if flag == "" { + return nil + } + for i := range optSet { + opt := &optSet[i] + if opt.Flag == flag { return opt } } diff --git a/values.go b/values.go index 554e9a6..79c8e2c 100644 --- a/values.go +++ b/values.go @@ -108,6 +108,30 @@ func (Int64) Type() string { return "int" } +type Float64 float64 + +func Float64Of(f *float64) *Float64 { + return (*Float64)(f) +} + +func (f *Float64) Set(s string) error { + ff, err := strconv.ParseFloat(s, 64) + *f = Float64(ff) + return err +} + +func (f Float64) Value() float64 { + return float64(f) +} + +func (f Float64) String() string { + return strconv.FormatFloat(float64(f), 'f', -1, 64) +} + +func (Float64) Type() string { + return "float64" +} + type Bool bool func BoolOf(b *bool) *Bool { @@ -167,7 +191,10 @@ func (String) Type() string { return "string" } -var _ pflag.SliceValue = &StringArray{} +var ( + _ pflag.SliceValue = &StringArray{} + _ pflag.Value = &StringArray{} +) // StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue. type StringArray []string @@ -503,7 +530,7 @@ func EnumOf(v *string, choices ...string) *Enum { func (e *Enum) Set(v string) error { for _, c := range e.Choices { - if v == c { + if strings.EqualFold(v, c) { *e.Value = v return nil } @@ -604,3 +631,76 @@ func (p *YAMLConfigPath) String() string { func (*YAMLConfigPath) Type() string { return "yaml-config-path" } + +var _ pflag.SliceValue = (*EnumArray)(nil) +var _ pflag.Value = (*EnumArray)(nil) + +type EnumArray struct { + Choices []string + Value *[]string +} + +func (e *EnumArray) Append(s string) error { + for _, c := range e.Choices { + if strings.EqualFold(s, c) { + *e.Value = append(*e.Value, s) + return nil + } + } + return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) +} + +func (e *EnumArray) GetSlice() []string { + return *e.Value +} + +func (e *EnumArray) Replace(ss []string) error { + for _, s := range ss { + found := false + for _, c := range e.Choices { + if strings.EqualFold(s, c) { + found = true + break + } + } + if !found { + return xerrors.Errorf("invalid choice: %s, should be one of %v", s, e.Choices) + } + } + *e.Value = ss + return nil +} + +func (e *EnumArray) Set(v string) error { + if v == "" { + *e.Value = nil + return nil + } + ss, err := readAsCSV(v) + if err != nil { + return err + } + for _, s := range ss { + err := e.Append(s) + if err != nil { + return err + } + } + return nil +} + +func (e *EnumArray) String() string { + return writeAsCSV(*e.Value) +} + +func (e *EnumArray) Type() string { + return fmt.Sprintf("enum-array[%v]", strings.Join(e.Choices, "\\|")) +} + +func EnumArrayOf(v *[]string, choices ...string) *EnumArray { + choices = append([]string{}, choices...) + return &EnumArray{ + Choices: choices, + Value: v, + } +}