diff --git a/README.md b/README.md index f4ada07..67873bf 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,8 @@ moq [flags] source-dir interface [interface2 [interface3 [...]]] go pretty-printer: gofmt, goimports or noop (default gofmt) -out string output file (default stdout) + -out-dir string + output dir (exclusive with -out) -pkg string package name (default will infer) -stub diff --git a/main.go b/main.go index 89adb3d..1699d95 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ var Version string = "dev" type userFlags struct { outFile string + outDir string pkgName string formatter string stubImpl bool @@ -30,6 +31,7 @@ type userFlags struct { func main() { var flags userFlags flag.StringVar(&flags.outFile, "out", "", "output file (default stdout)") + flag.StringVar(&flags.outDir, "out-dir", "", "output dir (exclusive with -out)") flag.StringVar(&flags.pkgName, "pkg", "", "package name (default will infer)") flag.StringVar(&flags.formatter, "fmt", "", "go pretty-printer: gofmt, goimports or noop (default gofmt)") flag.BoolVar(&flags.stubImpl, "stub", false, @@ -76,12 +78,6 @@ func run(flags userFlags) error { } } - var buf bytes.Buffer - var out io.Writer = os.Stdout - if flags.outFile != "" { - out = &buf - } - srcDir, args := flags.args[0], flags.args[1:] m, err := moq.New(moq.Config{ SrcDir: srcDir, @@ -95,19 +91,52 @@ func run(flags userFlags) error { return err } - if err = m.Mock(out, args...); err != nil { + switch { + case flags.outDir != "" && flags.outFile != "": + return errors.New("use only one of the flags -out and -out-dir") + case flags.outDir != "": + return mockToDir(m, flags.outDir, args...) + case flags.outFile != "": + return mockToFile(m, flags.outFile, args...) + default: + // mock to stdout + return m.Mock(os.Stdout, args...) + } +} + +func mockToDir(m *moq.Mocker, outDir string, args ...string) error { + if err := os.MkdirAll(outDir, 0o750); err != nil { return err } - if flags.outFile == "" { - return nil + var buf bytes.Buffer + for _, arg := range args { + if err := m.Mock(&buf, arg); err != nil { + return err + } + + filename := filepath.Join(outDir, m.FileMockName(arg)) + if err := ioutil.WriteFile(filename, buf.Bytes(), 0o600); err != nil { + return err + } + + buf.Reset() } - // create the file - err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750) - if err != nil { + return nil +} + +func mockToFile(m *moq.Mocker, outFile string, args ...string) error { + var buf bytes.Buffer + var out io.Writer = &buf + + if err := m.Mock(out, args...); err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(outFile), 0o750); err != nil { return err } - return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600) + return ioutil.WriteFile(outFile, buf.Bytes(), 0o600) } diff --git a/pkg/moq/moq.go b/pkg/moq/moq.go index e8a2975..c7379ff 100644 --- a/pkg/moq/moq.go +++ b/pkg/moq/moq.go @@ -7,6 +7,7 @@ import ( "go/types" "io" "strings" + "unicode" "github.com/matryer/moq/internal/registry" "github.com/matryer/moq/internal/template" @@ -114,6 +115,19 @@ func (m *Mocker) Mock(w io.Writer, namePairs ...string) error { return nil } +// FileMockName generates file name for mock from interface +func (m *Mocker) FileMockName(namePair string) string { + ifaceName, mockName := parseInterfaceName(namePair) + + var mockFile string + if strings.HasPrefix(mockName, ifaceName) { + mockFile = toSnakeCase(ifaceName) + } else { + mockFile = toSnakeCase(mockName) + } + return mockFile + ".go" +} + func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData { var tpd []template.TypeParamData if tparams == nil { @@ -210,3 +224,49 @@ func parseInterfaceName(namePair string) (ifaceName, mockName string) { ifaceName = parts[0] return ifaceName, ifaceName + "Mock" } + +func toSnakeCase(name string) string { + var buf bytes.Buffer + + fUpper := -1 + for i, r := range name { + if unicode.IsUpper(r) { + if fUpper < 0 { + fUpper = i + } + continue + } + + if fUpper < 0 { + // just next low rune + buf.WriteRune(r) + continue + } + + if fUpper == 0 && name[0] == 'I' { + // special case for interface preffix + buf.WriteString("i") + fUpper++ + } + if fUpper > 0 { + buf.WriteRune('_') + } + if i-fUpper >= 2 { + buf.WriteString(strings.ToLower(name[fUpper : i-1])) + buf.WriteRune('_') + } + + buf.WriteString(strings.ToLower(name[i-1 : i])) + buf.WriteRune(r) + fUpper = -1 + } + + if fUpper > 0 { + buf.WriteRune('_') + } + + if fUpper >= 0 { + buf.WriteString(strings.ToLower(name[fUpper:])) + } + return buf.String() +} diff --git a/pkg/moq/moq_test.go b/pkg/moq/moq_test.go index 7dceb8f..0320972 100644 --- a/pkg/moq/moq_test.go +++ b/pkg/moq/moq_test.go @@ -699,6 +699,31 @@ func TestMockError(t *testing.T) { } } +func TestToSnakeCase(t *testing.T) { + cases := []struct { + arg, want string + }{ + {"allcased", "allcased"}, + {"ALLUPPER", "allupper"}, + {"ICache", "i_cache"}, + {"IJournalService", "i_journal_service"}, + {"ITJournalService", "i_t_journal_service"}, + {"IAddressAPI", "i_address_api"}, + {"IABClient", "i_ab_client"}, + {"APIClient", "api_client"}, + {"IAPIClient", "i_api_client"}, + } + + for _, tc := range cases { + t.Run(tc.arg, func(t *testing.T) { + actual := toSnakeCase(tc.arg) + if actual != tc.want { + t.Errorf("expect: %v, actual: %v", tc.want, actual) + } + }) + } +} + // normalize normalizes \r\n (windows) and \r (mac) // into \n (unix) func normalize(d []byte) []byte {