Replace *bytes.Buffer with io.Writer
Also adds support for generating bash completions on writers other than just buffers. Found via github.com/mvdan/interfacer.
This commit is contained in:
		@ -3,6 +3,7 @@ package cobra
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@ -16,7 +17,7 @@ const (
 | 
				
			|||||||
	BashCompSubdirsInDir    = "cobra_annotation_bash_completion_subdirs_in_dir"
 | 
						BashCompSubdirsInDir    = "cobra_annotation_bash_completion_subdirs_in_dir"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func preamble(out *bytes.Buffer) {
 | 
					func preamble(out io.Writer) {
 | 
				
			||||||
	fmt.Fprintf(out, `#!/bin/bash
 | 
						fmt.Fprintf(out, `#!/bin/bash
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__debug()
 | 
					__debug()
 | 
				
			||||||
@ -207,10 +208,10 @@ __handle_word()
 | 
				
			|||||||
`)
 | 
					`)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func postscript(out *bytes.Buffer, name string) {
 | 
					func postscript(w io.Writer, name string) {
 | 
				
			||||||
	name = strings.Replace(name, ":", "__", -1)
 | 
						name = strings.Replace(name, ":", "__", -1)
 | 
				
			||||||
	fmt.Fprintf(out, "__start_%s()\n", name)
 | 
						fmt.Fprintf(w, "__start_%s()\n", name)
 | 
				
			||||||
	fmt.Fprintf(out, `{
 | 
						fmt.Fprintf(w, `{
 | 
				
			||||||
    local cur prev words cword
 | 
					    local cur prev words cword
 | 
				
			||||||
    declare -A flaghash
 | 
					    declare -A flaghash
 | 
				
			||||||
    if declare -F _init_completion >/dev/null 2>&1; then
 | 
					    if declare -F _init_completion >/dev/null 2>&1; then
 | 
				
			||||||
@ -234,55 +235,55 @@ func postscript(out *bytes.Buffer, name string) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
`, name)
 | 
					`, name)
 | 
				
			||||||
	fmt.Fprintf(out, `if [[ $(type -t compopt) = "builtin" ]]; then
 | 
						fmt.Fprintf(w, `if [[ $(type -t compopt) = "builtin" ]]; then
 | 
				
			||||||
    complete -o default -F __start_%s %s
 | 
					    complete -o default -F __start_%s %s
 | 
				
			||||||
else
 | 
					else
 | 
				
			||||||
    complete -o default -o nospace -F __start_%s %s
 | 
					    complete -o default -o nospace -F __start_%s %s
 | 
				
			||||||
fi
 | 
					fi
 | 
				
			||||||
 | 
					
 | 
				
			||||||
`, name, name, name, name)
 | 
					`, name, name, name, name)
 | 
				
			||||||
	fmt.Fprintf(out, "# ex: ts=4 sw=4 et filetype=sh\n")
 | 
						fmt.Fprintf(w, "# ex: ts=4 sw=4 et filetype=sh\n")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeCommands(cmd *Command, out *bytes.Buffer) {
 | 
					func writeCommands(cmd *Command, w io.Writer) {
 | 
				
			||||||
	fmt.Fprintf(out, "    commands=()\n")
 | 
						fmt.Fprintf(w, "    commands=()\n")
 | 
				
			||||||
	for _, c := range cmd.Commands() {
 | 
						for _, c := range cmd.Commands() {
 | 
				
			||||||
		if !c.IsAvailableCommand() || c == cmd.helpCommand {
 | 
							if !c.IsAvailableCommand() || c == cmd.helpCommand {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		fmt.Fprintf(out, "    commands+=(%q)\n", c.Name())
 | 
							fmt.Fprintf(w, "    commands+=(%q)\n", c.Name())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	fmt.Fprintf(out, "\n")
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeFlagHandler(name string, annotations map[string][]string, out *bytes.Buffer) {
 | 
					func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) {
 | 
				
			||||||
	for key, value := range annotations {
 | 
						for key, value := range annotations {
 | 
				
			||||||
		switch key {
 | 
							switch key {
 | 
				
			||||||
		case BashCompFilenameExt:
 | 
							case BashCompFilenameExt:
 | 
				
			||||||
			fmt.Fprintf(out, "    flags_with_completion+=(%q)\n", name)
 | 
								fmt.Fprintf(w, "    flags_with_completion+=(%q)\n", name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if len(value) > 0 {
 | 
								if len(value) > 0 {
 | 
				
			||||||
				ext := "__handle_filename_extension_flag " + strings.Join(value, "|")
 | 
									ext := "__handle_filename_extension_flag " + strings.Join(value, "|")
 | 
				
			||||||
				fmt.Fprintf(out, "    flags_completion+=(%q)\n", ext)
 | 
									fmt.Fprintf(w, "    flags_completion+=(%q)\n", ext)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				ext := "_filedir"
 | 
									ext := "_filedir"
 | 
				
			||||||
				fmt.Fprintf(out, "    flags_completion+=(%q)\n", ext)
 | 
									fmt.Fprintf(w, "    flags_completion+=(%q)\n", ext)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		case BashCompSubdirsInDir:
 | 
							case BashCompSubdirsInDir:
 | 
				
			||||||
			fmt.Fprintf(out, "    flags_with_completion+=(%q)\n", name)
 | 
								fmt.Fprintf(w, "    flags_with_completion+=(%q)\n", name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if len(value) == 1 {
 | 
								if len(value) == 1 {
 | 
				
			||||||
				ext := "__handle_subdirs_in_dir_flag " + value[0]
 | 
									ext := "__handle_subdirs_in_dir_flag " + value[0]
 | 
				
			||||||
				fmt.Fprintf(out, "    flags_completion+=(%q)\n", ext)
 | 
									fmt.Fprintf(w, "    flags_completion+=(%q)\n", ext)
 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				ext := "_filedir -d"
 | 
									ext := "_filedir -d"
 | 
				
			||||||
				fmt.Fprintf(out, "    flags_completion+=(%q)\n", ext)
 | 
									fmt.Fprintf(w, "    flags_completion+=(%q)\n", ext)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) {
 | 
					func writeShortFlag(flag *pflag.Flag, w io.Writer) {
 | 
				
			||||||
	b := (flag.Value.Type() == "bool")
 | 
						b := (flag.Value.Type() == "bool")
 | 
				
			||||||
	name := flag.Shorthand
 | 
						name := flag.Shorthand
 | 
				
			||||||
	format := "    "
 | 
						format := "    "
 | 
				
			||||||
@ -290,11 +291,11 @@ func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) {
 | 
				
			|||||||
		format += "two_word_"
 | 
							format += "two_word_"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	format += "flags+=(\"-%s\")\n"
 | 
						format += "flags+=(\"-%s\")\n"
 | 
				
			||||||
	fmt.Fprintf(out, format, name)
 | 
						fmt.Fprintf(w, format, name)
 | 
				
			||||||
	writeFlagHandler("-"+name, flag.Annotations, out)
 | 
						writeFlagHandler("-"+name, flag.Annotations, w)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeFlag(flag *pflag.Flag, out *bytes.Buffer) {
 | 
					func writeFlag(flag *pflag.Flag, w io.Writer) {
 | 
				
			||||||
	b := (flag.Value.Type() == "bool")
 | 
						b := (flag.Value.Type() == "bool")
 | 
				
			||||||
	name := flag.Name
 | 
						name := flag.Name
 | 
				
			||||||
	format := "    flags+=(\"--%s"
 | 
						format := "    flags+=(\"--%s"
 | 
				
			||||||
@ -302,35 +303,35 @@ func writeFlag(flag *pflag.Flag, out *bytes.Buffer) {
 | 
				
			|||||||
		format += "="
 | 
							format += "="
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	format += "\")\n"
 | 
						format += "\")\n"
 | 
				
			||||||
	fmt.Fprintf(out, format, name)
 | 
						fmt.Fprintf(w, format, name)
 | 
				
			||||||
	writeFlagHandler("--"+name, flag.Annotations, out)
 | 
						writeFlagHandler("--"+name, flag.Annotations, w)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeFlags(cmd *Command, out *bytes.Buffer) {
 | 
					func writeFlags(cmd *Command, w io.Writer) {
 | 
				
			||||||
	fmt.Fprintf(out, `    flags=()
 | 
						fmt.Fprintf(w, `    flags=()
 | 
				
			||||||
    two_word_flags=()
 | 
					    two_word_flags=()
 | 
				
			||||||
    flags_with_completion=()
 | 
					    flags_with_completion=()
 | 
				
			||||||
    flags_completion=()
 | 
					    flags_completion=()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
`)
 | 
					`)
 | 
				
			||||||
	cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
 | 
						cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) {
 | 
				
			||||||
		writeFlag(flag, out)
 | 
							writeFlag(flag, w)
 | 
				
			||||||
		if len(flag.Shorthand) > 0 {
 | 
							if len(flag.Shorthand) > 0 {
 | 
				
			||||||
			writeShortFlag(flag, out)
 | 
								writeShortFlag(flag, w)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
 | 
						cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) {
 | 
				
			||||||
		writeFlag(flag, out)
 | 
							writeFlag(flag, w)
 | 
				
			||||||
		if len(flag.Shorthand) > 0 {
 | 
							if len(flag.Shorthand) > 0 {
 | 
				
			||||||
			writeShortFlag(flag, out)
 | 
								writeShortFlag(flag, w)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Fprintf(out, "\n")
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeRequiredFlag(cmd *Command, out *bytes.Buffer) {
 | 
					func writeRequiredFlag(cmd *Command, w io.Writer) {
 | 
				
			||||||
	fmt.Fprintf(out, "    must_have_one_flag=()\n")
 | 
						fmt.Fprintf(w, "    must_have_one_flag=()\n")
 | 
				
			||||||
	flags := cmd.NonInheritedFlags()
 | 
						flags := cmd.NonInheritedFlags()
 | 
				
			||||||
	flags.VisitAll(func(flag *pflag.Flag) {
 | 
						flags.VisitAll(func(flag *pflag.Flag) {
 | 
				
			||||||
		for key := range flag.Annotations {
 | 
							for key := range flag.Annotations {
 | 
				
			||||||
@ -342,50 +343,50 @@ func writeRequiredFlag(cmd *Command, out *bytes.Buffer) {
 | 
				
			|||||||
					format += "="
 | 
										format += "="
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				format += "\")\n"
 | 
									format += "\")\n"
 | 
				
			||||||
				fmt.Fprintf(out, format, flag.Name)
 | 
									fmt.Fprintf(w, format, flag.Name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if len(flag.Shorthand) > 0 {
 | 
									if len(flag.Shorthand) > 0 {
 | 
				
			||||||
					fmt.Fprintf(out, "    must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)
 | 
										fmt.Fprintf(w, "    must_have_one_flag+=(\"-%s\")\n", flag.Shorthand)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func writeRequiredNoun(cmd *Command, out *bytes.Buffer) {
 | 
					func writeRequiredNoun(cmd *Command, w io.Writer) {
 | 
				
			||||||
	fmt.Fprintf(out, "    must_have_one_noun=()\n")
 | 
						fmt.Fprintf(w, "    must_have_one_noun=()\n")
 | 
				
			||||||
	sort.Sort(sort.StringSlice(cmd.ValidArgs))
 | 
						sort.Sort(sort.StringSlice(cmd.ValidArgs))
 | 
				
			||||||
	for _, value := range cmd.ValidArgs {
 | 
						for _, value := range cmd.ValidArgs {
 | 
				
			||||||
		fmt.Fprintf(out, "    must_have_one_noun+=(%q)\n", value)
 | 
							fmt.Fprintf(w, "    must_have_one_noun+=(%q)\n", value)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func gen(cmd *Command, out *bytes.Buffer) {
 | 
					func gen(cmd *Command, w io.Writer) {
 | 
				
			||||||
	for _, c := range cmd.Commands() {
 | 
						for _, c := range cmd.Commands() {
 | 
				
			||||||
		if !c.IsAvailableCommand() || c == cmd.helpCommand {
 | 
							if !c.IsAvailableCommand() || c == cmd.helpCommand {
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		gen(c, out)
 | 
							gen(c, w)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	commandName := cmd.CommandPath()
 | 
						commandName := cmd.CommandPath()
 | 
				
			||||||
	commandName = strings.Replace(commandName, " ", "_", -1)
 | 
						commandName = strings.Replace(commandName, " ", "_", -1)
 | 
				
			||||||
	commandName = strings.Replace(commandName, ":", "__", -1)
 | 
						commandName = strings.Replace(commandName, ":", "__", -1)
 | 
				
			||||||
	fmt.Fprintf(out, "_%s()\n{\n", commandName)
 | 
						fmt.Fprintf(w, "_%s()\n{\n", commandName)
 | 
				
			||||||
	fmt.Fprintf(out, "    last_command=%q\n", commandName)
 | 
						fmt.Fprintf(w, "    last_command=%q\n", commandName)
 | 
				
			||||||
	writeCommands(cmd, out)
 | 
						writeCommands(cmd, w)
 | 
				
			||||||
	writeFlags(cmd, out)
 | 
						writeFlags(cmd, w)
 | 
				
			||||||
	writeRequiredFlag(cmd, out)
 | 
						writeRequiredFlag(cmd, w)
 | 
				
			||||||
	writeRequiredNoun(cmd, out)
 | 
						writeRequiredNoun(cmd, w)
 | 
				
			||||||
	fmt.Fprintf(out, "}\n\n")
 | 
						fmt.Fprintf(w, "}\n\n")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cmd *Command) GenBashCompletion(out *bytes.Buffer) {
 | 
					func (cmd *Command) GenBashCompletion(w io.Writer) {
 | 
				
			||||||
	preamble(out)
 | 
						preamble(w)
 | 
				
			||||||
	if len(cmd.BashCompletionFunction) > 0 {
 | 
						if len(cmd.BashCompletionFunction) > 0 {
 | 
				
			||||||
		fmt.Fprintf(out, "%s\n", cmd.BashCompletionFunction)
 | 
							fmt.Fprintf(w, "%s\n", cmd.BashCompletionFunction)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	gen(cmd, out)
 | 
						gen(cmd, w)
 | 
				
			||||||
	postscript(out, cmd.Name())
 | 
						postscript(w, cmd.Name())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (cmd *Command) GenBashCompletionFile(filename string) error {
 | 
					func (cmd *Command) GenBashCompletionFile(filename string) error {
 | 
				
			||||||
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user