diff --git a/flag_text.go b/flag_text.go new file mode 100644 index 0000000000..fa049714d0 --- /dev/null +++ b/flag_text.go @@ -0,0 +1,53 @@ +package cli + +import ( + "encoding" +) + +type TextMarshalUnmarshaler interface { + encoding.TextMarshaler + encoding.TextUnmarshaler +} + +// TextFlag enables you to set types that satisfies [TextMarshalUnmarshaler] using flags such as log levels. +type TextFlag = FlagBase[TextMarshalUnmarshaler, NoConfig, TextValue] + +type TextValue struct { + Value TextMarshalUnmarshaler +} + +func (f TextValue) String() string { + text, err := f.Value.MarshalText() + if err != nil { + return "" + } + + return string(text) +} + +func (f TextValue) Set(s string) error { + return f.Value.UnmarshalText([]byte(s)) +} + +func (f TextValue) Get() any { + return f.Value +} + +func (f TextValue) Create(v TextMarshalUnmarshaler, p *TextMarshalUnmarshaler, _ NoConfig) Value { + pp := *p + if v != nil { + if b, err := v.MarshalText(); err == nil { + _ = pp.UnmarshalText(b) + } + } + + return &TextValue{ + Value: pp, + } +} + +func (f TextValue) ToString(v TextMarshalUnmarshaler) string { + text, _ := v.MarshalText() + + return string(text) +} diff --git a/flag_text_test.go b/flag_text_test.go new file mode 100644 index 0000000000..780fc53631 --- /dev/null +++ b/flag_text_test.go @@ -0,0 +1,168 @@ +package cli + +import ( + "encoding" + "errors" + "io" + "log/slog" + "slices" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type badMarshaller struct{} + +func (badMarshaller) UnmarshalText(_ []byte) error { + return nil +} + +func (badMarshaller) MarshalText() ([]byte, error) { + return nil, errors.New("bad") +} + +func ptr[T any](v T) *T { + return &v +} + +func TestTextFlag(t *testing.T) { + tests := []struct { + name string + flag *TextFlag + args []string + want string + wantErr bool + }{ + { + name: "empty", + flag: &TextFlag{ + Name: "log-level", + Value: &slog.LevelVar{}, + Destination: ptr[TextMarshalUnmarshaler](&slog.LevelVar{}), + }, + want: "INFO", + }, + { + name: "info", + flag: &TextFlag{ + Name: "log-level", + Value: &slog.LevelVar{}, + Destination: ptr[TextMarshalUnmarshaler](&slog.LevelVar{}), + Validator: func(v TextMarshalUnmarshaler) error { + text, err := v.MarshalText() + if err != nil { + return err + } + + if !slices.Equal(text, []byte("INFO")) { + return errors.New("expected \"INFO\"") + } + + return nil + }, + }, + args: []string{"--log-level", "info"}, + want: "INFO", + }, + { + name: "debug", + flag: &TextFlag{ + Name: "log-level", + Value: &slog.LevelVar{}, + Destination: ptr[TextMarshalUnmarshaler](&slog.LevelVar{}), + }, + args: []string{"--log-level", "debug"}, + want: "DEBUG", + }, + { + name: "invalid", + flag: &TextFlag{ + Name: "log-level", + Value: &slog.LevelVar{}, + Destination: ptr[TextMarshalUnmarshaler](&slog.LevelVar{}), + }, + args: []string{"--log-level", "invalid"}, + want: "INFO", + wantErr: true, + }, + { + name: "bad_marshaller", + flag: &TextFlag{ + Name: "text", + Value: &badMarshaller{}, + Destination: ptr[TextMarshalUnmarshaler](&badMarshaller{}), + }, + args: []string{"--text", "foo"}, + wantErr: true, + }, + { + name: "default", + flag: &TextFlag{ + Name: "log-level", + Value: func() *slog.LevelVar { + var l slog.LevelVar + + l.Set(slog.LevelWarn) + + return &l + }(), + Destination: ptr[TextMarshalUnmarshaler](&slog.LevelVar{}), + }, + args: []string{}, + want: "WARN", + }, + { + name: "override_default", + flag: &TextFlag{ + Name: "log-level", + Value: func() *slog.LevelVar { + var l slog.LevelVar + + l.Set(slog.LevelWarn) + + return &l + }(), + Destination: ptr[TextMarshalUnmarshaler](&slog.LevelVar{}), + }, + args: []string{"--log-level", "error"}, + want: "ERROR", + }, + } + + t.Parallel() + + for _, tt := range tests { + t.Parallel() + + t.Run(tt.name, func(t *testing.T) { + cmd := &Command{ + Name: tt.name, + Flags: []Flag{tt.flag}, + Writer: io.Discard, + ErrWriter: io.Discard, + } + + err := cmd.Run(buildTestContext(t), append([]string{"mock"}, tt.args...)) + + if err != nil && !tt.wantErr { + require.NoError(t, err) + + return + } else if err != nil { + return + } + + var got []byte + + got, err = tt.flag.Get().(encoding.TextMarshaler).MarshalText() + if tt.wantErr { + require.Error(t, err) + + return + } + + assert.Equal(t, tt.want, string(got)) + }) + } +} diff --git a/godoc-current.txt b/godoc-current.txt index c6f50b2eaf..01cf3ea28a 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -1347,6 +1347,29 @@ type SuggestCommandFunc func(commands []*Command, provided string) string type SuggestFlagFunc func(flags []Flag, provided string, hideHelp bool) string +type TextFlag = FlagBase[TextMarshalUnmarshaler, NoConfig, TextValue] + TextFlag enables you to set types that satisfies TextMarshalUnmarshaler + using flags such as log levels. + +type TextMarshalUnmarshaler interface { + encoding.TextMarshaler + encoding.TextUnmarshaler +} + +type TextValue struct { + Value TextMarshalUnmarshaler +} + +func (f TextValue) Create(v TextMarshalUnmarshaler, p *TextMarshalUnmarshaler, _ NoConfig) Value + +func (f TextValue) Get() any + +func (f TextValue) Set(s string) error + +func (f TextValue) String() string + +func (f TextValue) ToString(v TextMarshalUnmarshaler) string + type TimestampArg = ArgumentBase[time.Time, TimestampConfig, timestampValue] type TimestampArgs = ArgumentsBase[time.Time, TimestampConfig, timestampValue] diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index c6f50b2eaf..01cf3ea28a 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -1347,6 +1347,29 @@ type SuggestCommandFunc func(commands []*Command, provided string) string type SuggestFlagFunc func(flags []Flag, provided string, hideHelp bool) string +type TextFlag = FlagBase[TextMarshalUnmarshaler, NoConfig, TextValue] + TextFlag enables you to set types that satisfies TextMarshalUnmarshaler + using flags such as log levels. + +type TextMarshalUnmarshaler interface { + encoding.TextMarshaler + encoding.TextUnmarshaler +} + +type TextValue struct { + Value TextMarshalUnmarshaler +} + +func (f TextValue) Create(v TextMarshalUnmarshaler, p *TextMarshalUnmarshaler, _ NoConfig) Value + +func (f TextValue) Get() any + +func (f TextValue) Set(s string) error + +func (f TextValue) String() string + +func (f TextValue) ToString(v TextMarshalUnmarshaler) string + type TimestampArg = ArgumentBase[time.Time, TimestampConfig, timestampValue] type TimestampArgs = ArgumentsBase[time.Time, TimestampConfig, timestampValue]