Merge pull request #215 from mvdan/interfacer

Replace *bytes.Buffer with io.Writer
This commit is contained in:
Eric Paris 2016-01-05 16:50:52 -05:00
commit c82a0ceef8

View File

@ -1,8 +1,8 @@
package cobra package cobra
import ( import (
"bytes"
"fmt" "fmt"
"io"
"os" "os"
"sort" "sort"
"strings" "strings"
@ -16,8 +16,8 @@ const (
BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir" BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir"
) )
func preamble(out *bytes.Buffer) { func preamble(out io.Writer) error {
fmt.Fprintf(out, `#!/bin/bash _, err := fmt.Fprintf(out, `#!/bin/bash
__debug() __debug()
{ {
@ -205,12 +205,16 @@ __handle_word()
} }
`) `)
return err
} }
func postscript(out *bytes.Buffer, name string) { func postscript(w io.Writer, name string) error {
name = strings.Replace(name, ":", "__", -1) name = strings.Replace(name, ":", "__", -1)
fmt.Fprintf(out, "__start_%s()\n", name) _, err := fmt.Fprintf(w, "__start_%s()\n", name)
fmt.Fprintf(out, `{ if err != nil {
return err
}
_, err = fmt.Fprintf(w, `{
local cur prev words cword local cur prev words cword
declare -A flaghash declare -A flaghash
if declare -F _init_completion >/dev/null 2>&1; then if declare -F _init_completion >/dev/null 2>&1; then
@ -234,55 +238,77 @@ func postscript(out *bytes.Buffer, name string) {
} }
`, name) `, name)
fmt.Fprintf(out, `if [[ $(type -t compopt) = "builtin" ]]; then if err != nil {
return err
}
_, err = fmt.Fprintf(w, `if [[ $(type -t compopt) = "builtin" ]]; then
complete -o default -F __start_%s %s complete -o default -F __start_%s %s
else else
complete -o default -o nospace -F __start_%s %s complete -o default -o nospace -F __start_%s %s
fi fi
`, name, name, name, name) `, name, name, name, name)
fmt.Fprintf(out, "# ex: ts=4 sw=4 et filetype=sh\n") if err != nil {
return err
}
_, err = fmt.Fprintf(w, "# ex: ts=4 sw=4 et filetype=sh\n")
return err
} }
func writeCommands(cmd *Command, out *bytes.Buffer) { func writeCommands(cmd *Command, w io.Writer) error {
fmt.Fprintf(out, " commands=()\n") if _, err := fmt.Fprintf(w, " commands=()\n"); err != nil {
return err
}
for _, c := range cmd.Commands() { for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c == cmd.helpCommand { if !c.IsAvailableCommand() || c == cmd.helpCommand {
continue continue
} }
fmt.Fprintf(out, " commands+=(%q)\n", c.Name()) if _, err := fmt.Fprintf(w, " commands+=(%q)\n", c.Name()); err != nil {
return err
} }
fmt.Fprintf(out, "\n") }
_, err := fmt.Fprintf(w, "\n")
return err
} }
func writeFlagHandler(name string, annotations map[string][]string, out *bytes.Buffer) { func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) error {
for key, value := range annotations { for key, value := range annotations {
switch key { switch key {
case BashCompFilenameExt: case BashCompFilenameExt:
fmt.Fprintf(out, " flags_with_completion+=(%q)\n", name) _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name)
if err != nil {
return err
}
if len(value) > 0 { if len(value) > 0 {
ext := "__handle_filename_extension_flag " + strings.Join(value, "|") ext := "__handle_filename_extension_flag " + strings.Join(value, "|")
fmt.Fprintf(out, " flags_completion+=(%q)\n", ext) _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} else { } else {
ext := "_filedir" ext := "_filedir"
fmt.Fprintf(out, " flags_completion+=(%q)\n", ext) _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
}
if err != nil {
return err
} }
case BashCompSubdirsInDir: case BashCompSubdirsInDir:
fmt.Fprintf(out, " flags_with_completion+=(%q)\n", name) _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name)
if len(value) == 1 { if len(value) == 1 {
ext := "__handle_subdirs_in_dir_flag " + value[0] ext := "__handle_subdirs_in_dir_flag " + value[0]
fmt.Fprintf(out, " flags_completion+=(%q)\n", ext) _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
} else { } else {
ext := "_filedir -d" ext := "_filedir -d"
fmt.Fprintf(out, " flags_completion+=(%q)\n", ext) _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext)
}
if err != nil {
return err
} }
} }
} }
return nil
} }
func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) { func writeShortFlag(flag *pflag.Flag, w io.Writer) error {
b := (flag.Value.Type() == "bool") b := (flag.Value.Type() == "bool")
name := flag.Shorthand name := flag.Shorthand
format := " " format := " "
@ -290,11 +316,13 @@ func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) {
format += "two_word_" format += "two_word_"
} }
format += "flags+=(\"-%s\")\n" format += "flags+=(\"-%s\")\n"
fmt.Fprintf(out, format, name) if _, err := fmt.Fprintf(w, format, name); err != nil {
writeFlagHandler("-"+name, flag.Annotations, out) return err
}
return writeFlagHandler("-"+name, flag.Annotations, w)
} }
func writeFlag(flag *pflag.Flag, out *bytes.Buffer) { func writeFlag(flag *pflag.Flag, w io.Writer) error {
b := (flag.Value.Type() == "bool") b := (flag.Value.Type() == "bool")
name := flag.Name name := flag.Name
format := " flags+=(\"--%s" format := " flags+=(\"--%s"
@ -302,36 +330,64 @@ func writeFlag(flag *pflag.Flag, out *bytes.Buffer) {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
fmt.Fprintf(out, format, name) if _, err := fmt.Fprintf(w, format, name); err != nil {
writeFlagHandler("--"+name, flag.Annotations, out) return err
}
return writeFlagHandler("--"+name, flag.Annotations, w)
} }
func writeFlags(cmd *Command, out *bytes.Buffer) { func writeFlags(cmd *Command, w io.Writer) error {
fmt.Fprintf(out, ` flags=() _, err := fmt.Fprintf(w, ` flags=()
two_word_flags=() two_word_flags=()
flags_with_completion=() flags_with_completion=()
flags_completion=() flags_completion=()
`) `)
if err != nil {
return err
}
var visitErr error
cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
writeFlag(flag, out) if err := writeFlag(flag, w); err != nil {
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
writeShortFlag(flag, out) if err := writeShortFlag(flag, w); err != nil {
visitErr = err
return
}
} }
}) })
if visitErr != nil {
return visitErr
}
cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) { cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
writeFlag(flag, out) if err := writeFlag(flag, w); err != nil {
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
writeShortFlag(flag, out) if err := writeShortFlag(flag, w); err != nil {
visitErr = err
return
}
} }
}) })
if visitErr != nil {
fmt.Fprintf(out, "\n") return visitErr
} }
func writeRequiredFlag(cmd *Command, out *bytes.Buffer) { _, err = fmt.Fprintf(w, "\n")
fmt.Fprintf(out, " must_have_one_flag=()\n") return err
}
func writeRequiredFlag(cmd *Command, w io.Writer) error {
if _, err := fmt.Fprintf(w, " must_have_one_flag=()\n"); err != nil {
return err
}
flags := cmd.NonInheritedFlags() flags := cmd.NonInheritedFlags()
var visitErr error
flags.VisitAll(func(flag *pflag.Flag) { flags.VisitAll(func(flag *pflag.Flag) {
for key := range flag.Annotations { for key := range flag.Annotations {
switch key { switch key {
@ -342,68 +398,95 @@ func writeRequiredFlag(cmd *Command, out *bytes.Buffer) {
format += "=" format += "="
} }
format += "\")\n" format += "\")\n"
fmt.Fprintf(out, format, flag.Name) if _, err := fmt.Fprintf(w, format, flag.Name); err != nil {
visitErr = err
return
}
if len(flag.Shorthand) > 0 { if len(flag.Shorthand) > 0 {
fmt.Fprintf(out, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand) if _, err := fmt.Fprintf(w, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand); err != nil {
visitErr = err
return
}
} }
} }
} }
}) })
return visitErr
} }
func writeRequiredNoun(cmd *Command, out *bytes.Buffer) { func writeRequiredNoun(cmd *Command, w io.Writer) error {
fmt.Fprintf(out, " must_have_one_noun=()\n") if _, err := fmt.Fprintf(w, " must_have_one_noun=()\n"); err != nil {
return err
}
sort.Sort(sort.StringSlice(cmd.ValidArgs)) sort.Sort(sort.StringSlice(cmd.ValidArgs))
for _, value := range cmd.ValidArgs { for _, value := range cmd.ValidArgs {
fmt.Fprintf(out, " must_have_one_noun+=(%q)\n", value) if _, err := fmt.Fprintf(w, " must_have_one_noun+=(%q)\n", value); err != nil {
return err
} }
} }
return nil
}
func gen(cmd *Command, out *bytes.Buffer) { func gen(cmd *Command, w io.Writer) error {
for _, c := range cmd.Commands() { for _, c := range cmd.Commands() {
if !c.IsAvailableCommand() || c == cmd.helpCommand { if !c.IsAvailableCommand() || c == cmd.helpCommand {
continue continue
} }
gen(c, out) if err := gen(c, w); err != nil {
return err
}
} }
commandName := cmd.CommandPath() commandName := cmd.CommandPath()
commandName = strings.Replace(commandName, " ", "_", -1) commandName = strings.Replace(commandName, " ", "_", -1)
commandName = strings.Replace(commandName, ":", "__", -1) commandName = strings.Replace(commandName, ":", "__", -1)
fmt.Fprintf(out, "_%s()\n{\n", commandName) if _, err := fmt.Fprintf(w, "_%s()\n{\n", commandName); err != nil {
fmt.Fprintf(out, " last_command=%q\n", commandName) return err
writeCommands(cmd, out) }
writeFlags(cmd, out) if _, err := fmt.Fprintf(w, " last_command=%q\n", commandName); err != nil {
writeRequiredFlag(cmd, out) return err
writeRequiredNoun(cmd, out) }
fmt.Fprintf(out, "}\n\n") if err := writeCommands(cmd, w); err != nil {
return err
}
if err := writeFlags(cmd, w); err != nil {
return err
}
if err := writeRequiredFlag(cmd, w); err != nil {
return err
}
if err := writeRequiredNoun(cmd, w); err != nil {
return err
}
if _, err := fmt.Fprintf(w, "}\n\n"); err != nil {
return err
}
return nil
} }
func (cmd *Command) GenBashCompletion(out *bytes.Buffer) { func (cmd *Command) GenBashCompletion(w io.Writer) error {
preamble(out) if err := preamble(w); err != nil {
if len(cmd.BashCompletionFunction) > 0 { return err
fmt.Fprintf(out, "%s\n", cmd.BashCompletionFunction)
} }
gen(cmd, out) if len(cmd.BashCompletionFunction) > 0 {
postscript(out, cmd.Name()) if _, err := fmt.Fprintf(w, "%s\n", cmd.BashCompletionFunction); err != nil {
return err
}
}
if err := gen(cmd, w); err != nil {
return err
}
return postscript(w, cmd.Name())
} }
func (cmd *Command) GenBashCompletionFile(filename string) error { func (cmd *Command) GenBashCompletionFile(filename string) error {
out := new(bytes.Buffer)
cmd.GenBashCompletion(out)
outFile, err := os.Create(filename) outFile, err := os.Create(filename)
if err != nil { if err != nil {
return err return err
} }
defer outFile.Close() defer outFile.Close()
_, err = outFile.Write(out.Bytes()) return cmd.GenBashCompletion(outFile)
if err != nil {
return err
}
return nil
} }
// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists. // MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists.