diff --git a/completer.go b/completer.go index c60604d..9005125 100644 --- a/completer.go +++ b/completer.go @@ -7,6 +7,8 @@ import ( "github.com/carapace-sh/carapace/pkg/style" completer "github.com/carapace-sh/carapace/pkg/x" "github.com/reeflective/readline" + "github.com/spf13/cobra" + "github.com/spf13/pflag" "github.com/reeflective/console/internal/completion" "github.com/reeflective/console/internal/line" @@ -22,6 +24,7 @@ func (c *Console) complete(input []rune, pos int) readline.Completions { // Split the line as shell words, only using // what the right buffer (up to the cursor) args, prefixComp, prefixLine := completion.SplitArgs(input, pos) + resetCompletionFlagState(menu.Command, args) // Prepare arguments for the carapace completer // (we currently need those two dummies for avoiding a panic). @@ -91,6 +94,84 @@ func (c *Console) complete(input []rune, pos int) readline.Completions { return comps } +func resetCompletionFlagState(root *cobra.Command, args []string) { + if root == nil { + return + } + + target := findCompletionTarget(root, args) + _ = target.LocalFlags() + resetCompletionFlagDefaults(target) + resetArgsLenAtDash(target) +} + +func resetCompletionFlagDefaults(target *cobra.Command) { + if target == nil { + return + } + + target.Flags().VisitAll(func(flag *pflag.Flag) { + flag.Changed = false + switch value := flag.Value.(type) { + case pflag.SliceValue: + var res []string + if len(flag.DefValue) > 0 && flag.DefValue != "[]" { + res = append(res, flag.DefValue) + } + + _ = value.Replace(res) + default: + _ = flag.Value.Set(flag.DefValue) + } + }) +} + +func resetArgsLenAtDash(target *cobra.Command) { + for cmd := target; cmd != nil; cmd = cmd.Parent() { + resetFlagSetArgsLenAtDash(cmd.Flags(), cmd.DisplayName()) + resetFlagSetArgsLenAtDash(cmd.PersistentFlags(), cmd.DisplayName()) + } +} + +func resetFlagSetArgsLenAtDash(fs *pflag.FlagSet, name string) { + if fs == nil { + return + } + + fs.Init(name, pflag.ContinueOnError) +} + +func findCompletionTarget(root *cobra.Command, args []string) *cobra.Command { + cmd := root + for _, arg := range args { + if arg == "--" || strings.HasPrefix(arg, "-") { + break + } + + next := findSubcommand(cmd, arg) + if next == nil { + break + } + cmd = next + } + + return cmd +} + +func findSubcommand(cmd *cobra.Command, name string) *cobra.Command { + if cmd == nil { + return nil + } + + for _, sub := range cmd.Commands() { + if sub.Name() == name || sub.HasAlias(name) { + return sub + } + } + + return nil +} + // justifyCommandComps justifies the descriptions for all commands in all groups // to the same level, for prettiness. Also, removes any coloring from them, as currently, // the carapace engine does add coloring to each group, and we don't want this. diff --git a/completer_test.go b/completer_test.go new file mode 100644 index 0000000..acf3f9b --- /dev/null +++ b/completer_test.go @@ -0,0 +1,55 @@ +package console + +import ( + "testing" + + "github.com/spf13/cobra" +) + +func TestCompleteResetsFlagDefaults(t *testing.T) { + c := New("test") + root := &cobra.Command{Use: "root"} + cmd := &cobra.Command{Use: "serve"} + cmd.Flags().Bool("verbose", false, "") + root.AddCommand(cmd) + c.activeMenu().Command = root + + if err := cmd.Flags().Set("verbose", "true"); err != nil { + t.Fatal(err) + } + + _ = c.complete([]rune("serve "), len("serve ")) + + flag := cmd.Flags().Lookup("verbose") + if flag == nil { + t.Fatal("missing verbose flag") + } + if flag.Changed { + t.Fatal("completion did not clear flag Changed state") + } + if flag.Value.String() != "false" { + t.Fatalf("flag value = %q, want false", flag.Value.String()) + } +} + +func TestCompleteResetsArgsLenAtDash(t *testing.T) { + c := New("test") + root := &cobra.Command{Use: "root"} + cmd := &cobra.Command{Use: "serve"} + cmd.Flags().Bool("verbose", false, "") + root.AddCommand(cmd) + c.activeMenu().Command = root + + if err := cmd.Flags().Parse([]string{"--", "positional"}); err != nil { + t.Fatal(err) + } + if got := cmd.Flags().ArgsLenAtDash(); got < 0 { + t.Fatalf("test setup did not set ArgsLenAtDash: %d", got) + } + + _ = c.complete([]rune("serve "), len("serve ")) + + if got := cmd.Flags().ArgsLenAtDash(); got != -1 { + t.Fatalf("ArgsLenAtDash = %d, want -1", got) + } +}