From c22c05ea008740df941e4834f4e354b2e47bab50 Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Wed, 22 Feb 2023 11:36:32 -0800 Subject: [PATCH 01/10] argparse: do not rely on os.Exit() inside FlagSet.Parse() When the internal 'FlagSet.Parse()' of an 'argparser' encounters an invalid flag, the current 'ExitOnError' error handling causes it to invoke the 'os.Exit(2)' syscall and exit the program abruptly. In later patches, we're going to want to "catch" that exit status in our logs, so change the error handling to 'ContinueOnError' and explicitly call 'os.Exit(2)' in 'argparse.Parse()' if 'FlagSet.Parse()' returns an error. Signed-off-by: Victoria Dye --- internal/argparse/argparse.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index 3048469..77d1d2c 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -7,6 +7,9 @@ import ( "strings" ) +// For consistency with 'flag', use 2 as the usage-related error code +const usageExitCode int = 2 + type positionalArg struct { name string description string @@ -30,7 +33,7 @@ type argParser struct { } func NewArgParser(usageString string) *argParser { - flagSet := flag.NewFlagSet("", flag.ExitOnError) + flagSet := flag.NewFlagSet("", flag.ContinueOnError) a := &argParser{ isTopLevel: false, @@ -136,7 +139,9 @@ func (a *argParser) Parse(args []string) { err := a.FlagSet.Parse(args) if err != nil { - panic("argParser FlagSet error handling should be 'ExitOnError', but error encountered") + // The error was already printed (via a.FlagSet.Usage()), so we + // just need to exit + os.Exit(usageExitCode) } if len(a.subcommands) > 0 { @@ -212,6 +217,5 @@ func (a *argParser) Usage(errFmt string, args ...any) { fmt.Fprintf(a.FlagSet.Output(), errFmt+"\n", args...) a.FlagSet.Usage() - // Exit with error code 2 to match flag.Parse() behavior - os.Exit(2) + os.Exit(usageExitCode) } From e8713a4bd2fb215aebf649796cede3128eb1f1bc Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Tue, 21 Feb 2023 10:57:03 -0800 Subject: [PATCH 02/10] git-bundle-server: propagate 'context.Context' through call stack The 'context.Context' structure is used, among other things, to hold request, command, function, etc. scoped information for throughout the call stack. In particular, this information scoping will be useful for logging (e.g., holding a session ID or a region nesting level) added in future patches. Following Go convention [1], pass the 'ctx' from the 'main()' function of the 'git-bundle-server' into child structures that will implement the initial round of logging in the application (i.e., the arg parser, 'git-bundle-server' commands, and daemon providers). [1] https://pkg.go.dev/context Signed-off-by: Victoria Dye --- cmd/git-bundle-server/delete.go | 5 +++-- cmd/git-bundle-server/init.go | 5 +++-- cmd/git-bundle-server/main.go | 7 +++++-- cmd/git-bundle-server/start.go | 5 +++-- cmd/git-bundle-server/stop.go | 6 ++++-- cmd/git-bundle-server/update-all.go | 5 +++-- cmd/git-bundle-server/update.go | 5 +++-- cmd/git-bundle-server/web-server.go | 25 +++++++++++++------------ cmd/git-bundle-web-server/main.go | 8 +++++--- cmd/utils/common-args.go | 11 ++++++----- internal/argparse/argparse.go | 15 ++++++++------- internal/argparse/subcommand.go | 12 +++++++----- internal/daemon/daemon.go | 9 +++++---- internal/daemon/launchd.go | 23 ++++++++++++----------- internal/daemon/launchd_test.go | 21 +++++++++++++++------ internal/daemon/systemd.go | 15 ++++++++------- internal/daemon/systemd_test.go | 25 +++++++++++++++++-------- 17 files changed, 120 insertions(+), 82 deletions(-) diff --git a/cmd/git-bundle-server/delete.go b/cmd/git-bundle-server/delete.go index 6233667..ca48ec0 100644 --- a/cmd/git-bundle-server/delete.go +++ b/cmd/git-bundle-server/delete.go @@ -1,6 +1,7 @@ package main import ( + "context" "os" "github.com/github/git-bundle-server/internal/argparse" @@ -19,10 +20,10 @@ Remove the configuration for the given '' and delete its repository data.` } -func (Delete) Run(args []string) error { +func (Delete) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server delete ") route := parser.PositionalString("route", "the route to delete") - parser.Parse(args) + parser.Parse(ctx, args) repo, err := core.CreateRepository(*route) if err != nil { diff --git a/cmd/git-bundle-server/init.go b/cmd/git-bundle-server/init.go index 0daf65e..7ee9e56 100644 --- a/cmd/git-bundle-server/init.go +++ b/cmd/git-bundle-server/init.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/github/git-bundle-server/internal/argparse" @@ -21,12 +22,12 @@ Initialize a repository by cloning a bare repo from '', whose bundles should be hosted at ''.` } -func (Init) Run(args []string) error { +func (Init) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server init ") url := parser.PositionalString("url", "the URL of a repository to clone") // TODO: allow parsing out of route := parser.PositionalString("route", "the route to host the specified repo") - parser.Parse(args) + parser.Parse(ctx, args) repo, err := core.CreateRepository(*route) if err != nil { diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index 5ed5b51..1a0f64f 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "os" @@ -20,6 +21,8 @@ func all() []argparse.Subcommand { } func main() { + ctx := context.Background() + cmds := all() parser := argparse.NewArgParser("git-bundle-server []") @@ -27,9 +30,9 @@ func main() { for _, cmd := range cmds { parser.Subcommand(cmd) } - parser.Parse(os.Args[1:]) + parser.Parse(ctx, os.Args[1:]) - err := parser.InvokeSubcommand() + err := parser.InvokeSubcommand(ctx) if err != nil { log.Fatal("Failed with error: ", err) } diff --git a/cmd/git-bundle-server/start.go b/cmd/git-bundle-server/start.go index 3c6270e..6123bc9 100644 --- a/cmd/git-bundle-server/start.go +++ b/cmd/git-bundle-server/start.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" @@ -20,10 +21,10 @@ Start computing bundles and serving content for the repository at the specified ''.` } -func (Start) Run(args []string) error { +func (Start) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server start ") route := parser.PositionalString("route", "the route for which bundles should be generated") - parser.Parse(args) + parser.Parse(ctx, args) // CreateRepository registers the route. repo, err := core.CreateRepository(*route) diff --git a/cmd/git-bundle-server/stop.go b/cmd/git-bundle-server/stop.go index dec74ac..920ae54 100644 --- a/cmd/git-bundle-server/stop.go +++ b/cmd/git-bundle-server/stop.go @@ -1,6 +1,8 @@ package main import ( + "context" + "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" ) @@ -17,10 +19,10 @@ Stop computing bundles or serving content for the repository at the specified ''.` } -func (Stop) Run(args []string) error { +func (Stop) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server stop ") route := parser.PositionalString("route", "the route for which bundles should stop being generated") - parser.Parse(args) + parser.Parse(ctx, args) return core.RemoveRoute(*route) } diff --git a/cmd/git-bundle-server/update-all.go b/cmd/git-bundle-server/update-all.go index 671dac4..79c9300 100644 --- a/cmd/git-bundle-server/update-all.go +++ b/cmd/git-bundle-server/update-all.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "os/exec" @@ -21,7 +22,7 @@ func (UpdateAll) Description() string { For every configured route, run 'git-bundle-server update '.` } -func (UpdateAll) Run(args []string) error { +func (UpdateAll) Run(ctx context.Context, args []string) error { user, err := common.NewUserProvider().CurrentUser() if err != nil { return err @@ -29,7 +30,7 @@ func (UpdateAll) Run(args []string) error { fs := common.NewFileSystem() parser := argparse.NewArgParser("git-bundle-server update-all") - parser.Parse(args) + parser.Parse(ctx, args) exe, err := os.Executable() if err != nil { diff --git a/cmd/git-bundle-server/update.go b/cmd/git-bundle-server/update.go index 514f610..bc52dcb 100644 --- a/cmd/git-bundle-server/update.go +++ b/cmd/git-bundle-server/update.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/github/git-bundle-server/internal/argparse" @@ -21,10 +22,10 @@ For the repository in the current directory (or the one specified by bundles, and update the bundle list.` } -func (Update) Run(args []string) error { +func (Update) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server update ") route := parser.PositionalString("route", "the route to update") - parser.Parse(args) + parser.Parse(ctx, args) repo, err := core.CreateRepository(*route) if err != nil { diff --git a/cmd/git-bundle-server/web-server.go b/cmd/git-bundle-server/web-server.go index c79aa6f..6130673 100644 --- a/cmd/git-bundle-server/web-server.go +++ b/cmd/git-bundle-server/web-server.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "flag" "fmt" @@ -76,7 +77,7 @@ func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) { }, nil } -func (w *webServer) startServer(args []string) error { +func (w *webServer) startServer(ctx context.Context, args []string) error { // Parse subcommand arguments parser := argparse.NewArgParser("git-bundle-server web-server start [-f|--force]") @@ -90,8 +91,8 @@ func (w *webServer) startServer(args []string) error { parser.Var(f.Value, f.Name, fmt.Sprintf("[Web server] %s", f.Usage)) }) - parser.Parse(args) - validate() + parser.Parse(ctx, args) + validate(ctx) d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem) if err != nil { @@ -130,12 +131,12 @@ func (w *webServer) startServer(args []string) error { return loopErr } - err = d.Create(config, *force) + err = d.Create(ctx, config, *force) if err != nil { return err } - err = d.Start(config.Label) + err = d.Start(ctx, config.Label) if err != nil { return err } @@ -143,11 +144,11 @@ func (w *webServer) startServer(args []string) error { return nil } -func (w *webServer) stopServer(args []string) error { +func (w *webServer) stopServer(ctx context.Context, args []string) error { // Parse subcommand arguments parser := argparse.NewArgParser("git-bundle-server web-server stop [--remove]") remove := parser.Bool("remove", false, "Remove the web server daemon configuration from the system after stopping") - parser.Parse(args) + parser.Parse(ctx, args) d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem) if err != nil { @@ -159,13 +160,13 @@ func (w *webServer) stopServer(args []string) error { return err } - err = d.Stop(config.Label) + err = d.Stop(ctx, config.Label) if err != nil { return err } if *remove { - err = d.Remove(config.Label) + err = d.Remove(ctx, config.Label) if err != nil { return err } @@ -174,12 +175,12 @@ func (w *webServer) stopServer(args []string) error { return nil } -func (w *webServer) Run(args []string) error { +func (w *webServer) Run(ctx context.Context, args []string) error { // Parse command arguments parser := argparse.NewArgParser("git-bundle-server web-server (start|stop) ") parser.Subcommand(argparse.NewSubcommand("start", "Start the web server", w.startServer)) parser.Subcommand(argparse.NewSubcommand("stop", "Stop the web server", w.stopServer)) - parser.Parse(args) + parser.Parse(ctx, args) - return parser.InvokeSubcommand() + return parser.InvokeSubcommand(ctx) } diff --git a/cmd/git-bundle-web-server/main.go b/cmd/git-bundle-web-server/main.go index b52ff2d..2f5589f 100644 --- a/cmd/git-bundle-web-server/main.go +++ b/cmd/git-bundle-web-server/main.go @@ -108,13 +108,15 @@ func startServer(server *http.Server, } func main() { + ctx := context.Background() + parser := argparse.NewArgParser("git-bundle-web-server [--port ] [--cert --key ]") flags, validate := utils.WebServerFlags(parser) flags.VisitAll(func(f *flag.Flag) { parser.Var(f.Value, f.Name, f.Usage) }) - parser.Parse(os.Args[1:]) - validate() + parser.Parse(ctx, os.Args[1:]) + validate(ctx) // Get the flag values port := utils.GetFlagValue[string](parser, "port") @@ -139,7 +141,7 @@ func main() { go func() { <-c fmt.Println("Starting graceful server shutdown...") - server.Shutdown(context.Background()) + server.Shutdown(ctx) }() // Wait for server to shut down diff --git a/cmd/utils/common-args.go b/cmd/utils/common-args.go index 9f6af20..2331ae8 100644 --- a/cmd/utils/common-args.go +++ b/cmd/utils/common-args.go @@ -1,6 +1,7 @@ package utils import ( + "context" "flag" "fmt" "strconv" @@ -14,7 +15,7 @@ import ( // functions we want to call from the parser. type argParser interface { Lookup(name string) *flag.Flag - Usage(errFmt string, args ...any) + Usage(ctx context.Context, errFmt string, args ...any) } func GetFlagValue[T any](parser argParser, name string) T { @@ -38,20 +39,20 @@ func GetFlagValue[T any](parser argParser, name string) T { // Sets of flags shared between multiple commands/programs -func WebServerFlags(parser argParser) (*flag.FlagSet, func()) { +func WebServerFlags(parser argParser) (*flag.FlagSet, func(context.Context)) { f := flag.NewFlagSet("", flag.ContinueOnError) port := f.String("port", "8080", "The port on which the server should be hosted") cert := f.String("cert", "", "The path to the X.509 SSL certificate file to use in securely hosting the server") key := f.String("key", "", "The path to the certificate's private key") // Function to call for additional arg validation (may exit with 'Usage()') - validationFunc := func() { + validationFunc := func(ctx context.Context) { p, err := strconv.Atoi(*port) if err != nil || p < 0 || p > 65535 { - parser.Usage("Invalid port '%s'.", *port) + parser.Usage(ctx, "Invalid port '%s'.", *port) } if (*cert == "") != (*key == "") { - parser.Usage("Both '--cert' and '--key' are needed to specify SSL configuration.") + parser.Usage(ctx, "Both '--cert' and '--key' are needed to specify SSL configuration.") } } diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index 77d1d2c..1f42d7e 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -1,6 +1,7 @@ package argparse import ( + "context" "flag" "fmt" "os" @@ -117,7 +118,7 @@ func (a *argParser) PositionalList(name string, description string) *[]string { return arg } -func (a *argParser) Parse(args []string) { +func (a *argParser) Parse(ctx context.Context, args []string) { if a.parsed { // Do nothing if we've already parsed args return @@ -147,12 +148,12 @@ func (a *argParser) Parse(args []string) { if len(a.subcommands) > 0 { // Parse subcommand, if applicable if a.FlagSet.NArg() == 0 { - a.Usage("Please specify a subcommand") + a.Usage(ctx, "Please specify a subcommand") } subcommand, exists := a.subcommands[a.FlagSet.Arg(0)] if !exists { - a.Usage("Invalid subcommand '%s'", a.FlagSet.Arg(0)) + a.Usage(ctx, "Invalid subcommand '%s'", a.FlagSet.Arg(0)) } else { a.selectedSubcommand = subcommand a.argOffset++ @@ -182,7 +183,7 @@ func (a *argParser) Parse(args []string) { if a.NArg() != 0 { // If not using subcommands, all args should be accounted for // Exit with usage if not - a.Usage("Unused arguments specified: %s", strings.Join(a.Args(), " ")) + a.Usage(ctx, "Unused arguments specified: %s", strings.Join(a.Args(), " ")) } } @@ -205,15 +206,15 @@ func (a *argParser) NArg() int { } } -func (a *argParser) InvokeSubcommand() error { +func (a *argParser) InvokeSubcommand(ctx context.Context) error { if !a.parsed || a.selectedSubcommand == nil { panic("subcommand has not been parsed") } - return a.selectedSubcommand.Run(a.Args()) + return a.selectedSubcommand.Run(ctx, a.Args()) } -func (a *argParser) Usage(errFmt string, args ...any) { +func (a *argParser) Usage(ctx context.Context, errFmt string, args ...any) { fmt.Fprintf(a.FlagSet.Output(), errFmt+"\n", args...) a.FlagSet.Usage() diff --git a/internal/argparse/subcommand.go b/internal/argparse/subcommand.go index 5e8b038..77f7b8c 100644 --- a/internal/argparse/subcommand.go +++ b/internal/argparse/subcommand.go @@ -1,21 +1,23 @@ package argparse +import "context" + type Subcommand interface { Name() string Description() string - Run(args []string) error + Run(ctx context.Context, args []string) error } type genericSubcommand struct { nameStr string descriptionStr string - runFunc func([]string) error + runFunc func(context.Context, []string) error } func NewSubcommand( name string, description string, - runFunc func([]string) error, + runFunc func(context.Context, []string) error, ) *genericSubcommand { return &genericSubcommand{ nameStr: name, @@ -32,6 +34,6 @@ func (s *genericSubcommand) Description() string { return s.descriptionStr } -func (s *genericSubcommand) Run(args []string) error { - return s.runFunc(args) +func (s *genericSubcommand) Run(ctx context.Context, args []string) error { + return s.runFunc(ctx, args) } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index c3a2c0b..7d5d1ba 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -1,6 +1,7 @@ package daemon import ( + "context" "fmt" "runtime" @@ -15,13 +16,13 @@ type DaemonConfig struct { } type DaemonProvider interface { - Create(config *DaemonConfig, force bool) error + Create(ctx context.Context, config *DaemonConfig, force bool) error - Start(label string) error + Start(ctx context.Context, label string) error - Stop(label string) error + Stop(ctx context.Context, label string) error - Remove(label string) error + Remove(ctx context.Context, label string) error } func NewDaemonProvider( diff --git a/internal/daemon/launchd.go b/internal/daemon/launchd.go index 235c86e..2d4a12c 100644 --- a/internal/daemon/launchd.go +++ b/internal/daemon/launchd.go @@ -2,6 +2,7 @@ package daemon import ( "bytes" + "context" "encoding/xml" "fmt" "path/filepath" @@ -107,7 +108,7 @@ func NewLaunchdProvider( } } -func (l *launchd) isBootstrapped(serviceTarget string) (bool, error) { +func (l *launchd) isBootstrapped(ctx context.Context, serviceTarget string) (bool, error) { // run 'launchctl print' on given service target to see if it exists exitCode, err := l.cmdExec.Run("launchctl", "print", serviceTarget) if err != nil { @@ -124,7 +125,7 @@ func (l *launchd) isBootstrapped(serviceTarget string) (bool, error) { } } -func (l *launchd) bootstrapFile(domain string, filename string) error { +func (l *launchd) bootstrapFile(ctx context.Context, domain string, filename string) error { // run 'launchctl bootstrap' on given domain & file exitCode, err := l.cmdExec.Run("launchctl", "bootstrap", domain, filename) if err != nil { @@ -138,7 +139,7 @@ func (l *launchd) bootstrapFile(domain string, filename string) error { return nil } -func (l *launchd) bootout(serviceTarget string) (bool, error) { +func (l *launchd) bootout(ctx context.Context, serviceTarget string) (bool, error) { // run 'launchctl bootout' on given service target exitCode, err := l.cmdExec.Run("launchctl", "bootout", serviceTarget) if err != nil { @@ -154,7 +155,7 @@ func (l *launchd) bootout(serviceTarget string) (bool, error) { } } -func (l *launchd) Create(config *DaemonConfig, force bool) error { +func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) error { // Add launchd-specific config lConfig := &launchdConfig{ DaemonConfig: *config, @@ -184,7 +185,7 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, config.Label) - alreadyLoaded, err := l.isBootstrapped(serviceTarget) + alreadyLoaded, err := l.isBootstrapped(ctx, serviceTarget) if err != nil { return err } @@ -202,7 +203,7 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { // Unload the service so we can reconfigure & reload if alreadyLoaded { - _, err = l.bootout(serviceTarget) + _, err = l.bootout(ctx, serviceTarget) if err != nil { return fmt.Errorf("could not bootout daemon process '%s': %w", config.Label, err) } @@ -217,7 +218,7 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { } } - err = l.bootstrapFile(domainTarget, filename) + err = l.bootstrapFile(ctx, domainTarget, filename) if err != nil { return fmt.Errorf("could not bootstrap daemon process '%s': %w", config.Label, err) } @@ -225,7 +226,7 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { return nil } -func (l *launchd) Start(label string) error { +func (l *launchd) Start(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { return fmt.Errorf("could not get current user for launchd service: %w", err) @@ -245,7 +246,7 @@ func (l *launchd) Start(label string) error { return nil } -func (l *launchd) Stop(label string) error { +func (l *launchd) Stop(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { return fmt.Errorf("could not get current user for launchd service: %w", err) @@ -268,7 +269,7 @@ func (l *launchd) Stop(label string) error { return nil } -func (l *launchd) Remove(label string) error { +func (l *launchd) Remove(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { return fmt.Errorf("could not get current user for launchd service: %w", err) @@ -278,7 +279,7 @@ func (l *launchd) Remove(label string) error { domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, label) - _, err = l.bootout(serviceTarget) + _, err = l.bootout(ctx, serviceTarget) if err != nil { return fmt.Errorf("could not remove daemon process '%s': %w", label, err) } diff --git a/internal/daemon/launchd_test.go b/internal/daemon/launchd_test.go index 59d1084..1ef6703 100644 --- a/internal/daemon/launchd_test.go +++ b/internal/daemon/launchd_test.go @@ -1,6 +1,7 @@ package daemon_test import ( + "context" "encoding/xml" "fmt" "os/user" @@ -237,6 +238,8 @@ func TestLaunchd_Create(t *testing.T) { testFileSystem := &MockFileSystem{} + ctx := context.Background() + launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, testFileSystem) // Verify launchd commands called @@ -276,7 +279,7 @@ func TestLaunchd_Create(t *testing.T) { } // Run "Create" - err := launchd.Create(tt.config, force) + err := launchd.Create(ctx, tt.config, force) // Assert on expected values if tt.expectErr { @@ -325,7 +328,7 @@ func TestLaunchd_Create(t *testing.T) { }), ).Return(nil).Once() - err := launchd.Create(tt.config, false) + err := launchd.Create(ctx, tt.config, false) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor, testFileSystem) @@ -362,6 +365,8 @@ func TestLaunchd_Start(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} + ctx := context.Background() + launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, nil) // Test #1: launchctl succeeds @@ -371,7 +376,7 @@ func TestLaunchd_Start(t *testing.T) { []string{"kickstart", fmt.Sprintf("user/123/%s", basicDaemonConfig.Label)}, ).Return(0, nil).Once() - err := launchd.Start(basicDaemonConfig.Label) + err := launchd.Start(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -386,7 +391,7 @@ func TestLaunchd_Start(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(1, nil).Once() - err := launchd.Start(basicDaemonConfig.Label) + err := launchd.Start(ctx, basicDaemonConfig.Label) assert.NotNil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -441,6 +446,8 @@ func TestLaunchd_Stop(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} + ctx := context.Background() + launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, nil) for _, tt := range launchdStopTests { @@ -454,7 +461,7 @@ func TestLaunchd_Stop(t *testing.T) { } // Call function - err := launchd.Stop(tt.label) + err := launchd.Stop(ctx, tt.label) mock.AssertExpectationsForObjects(t, testCommandExecutor) if tt.expectErr { assert.NotNil(t, err) @@ -531,6 +538,8 @@ func TestLaunchd_Remove(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} testFileSystem := &MockFileSystem{} + ctx := context.Background() + launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range launchdRemoveTests { @@ -552,7 +561,7 @@ func TestLaunchd_Remove(t *testing.T) { } // Call function - err := launchd.Remove(tt.label) + err := launchd.Remove(ctx, tt.label) mock.AssertExpectationsForObjects(t, testCommandExecutor) if tt.expectErr { assert.NotNil(t, err) diff --git a/internal/daemon/systemd.go b/internal/daemon/systemd.go index 545277e..0d22271 100644 --- a/internal/daemon/systemd.go +++ b/internal/daemon/systemd.go @@ -2,6 +2,7 @@ package daemon import ( "bytes" + "context" "fmt" "path/filepath" "strings" @@ -38,7 +39,7 @@ func NewSystemdProvider( } } -func (s *systemd) reloadDaemon() error { +func (s *systemd) reloadDaemon(ctx context.Context) error { exitCode, err := s.cmdExec.Run("systemctl", "--user", "daemon-reload") if err != nil { return err @@ -51,7 +52,7 @@ func (s *systemd) reloadDaemon() error { return nil } -func (s *systemd) Create(config *DaemonConfig, force bool) error { +func (s *systemd) Create(ctx context.Context, config *DaemonConfig, force bool) error { user, err := s.user.CurrentUser() if err != nil { return fmt.Errorf("could not get current user for systemd service: %w", err) @@ -89,7 +90,7 @@ func (s *systemd) Create(config *DaemonConfig, force bool) error { } // Reload the user-scoped service units after adding - err = s.reloadDaemon() + err = s.reloadDaemon(ctx) if err != nil { return err } @@ -97,7 +98,7 @@ func (s *systemd) Create(config *DaemonConfig, force bool) error { return nil } -func (s *systemd) Start(label string) error { +func (s *systemd) Start(ctx context.Context, label string) error { // TODO: warn user if already running exitCode, err := s.cmdExec.Run("systemctl", "--user", "start", label) if err != nil { @@ -111,7 +112,7 @@ func (s *systemd) Start(label string) error { return nil } -func (s *systemd) Stop(label string) error { +func (s *systemd) Stop(ctx context.Context, label string) error { // TODO: warn user if already stopped exitCode, err := s.cmdExec.Run("systemctl", "--user", "stop", label) if err != nil { @@ -125,7 +126,7 @@ func (s *systemd) Stop(label string) error { return nil } -func (s *systemd) Remove(label string) error { +func (s *systemd) Remove(ctx context.Context, label string) error { user, err := s.user.CurrentUser() if err != nil { return fmt.Errorf("could not get current user for launchd service: %w", err) @@ -138,7 +139,7 @@ func (s *systemd) Remove(label string) error { } // Reload the user-scoped service units after removing - err = s.reloadDaemon() + err = s.reloadDaemon(ctx) if err != nil { return err } diff --git a/internal/daemon/systemd_test.go b/internal/daemon/systemd_test.go index 0ecc11a..10fe837 100644 --- a/internal/daemon/systemd_test.go +++ b/internal/daemon/systemd_test.go @@ -1,6 +1,7 @@ package daemon_test import ( + "context" "fmt" "os/user" "path/filepath" @@ -137,6 +138,8 @@ func TestSystemd_Create(t *testing.T) { testFileSystem := &MockFileSystem{} + ctx := context.Background() + systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdCreateBehaviorTests { @@ -163,7 +166,7 @@ func TestSystemd_Create(t *testing.T) { } // Run "Create" - err := systemd.Create(tt.config, force) + err := systemd.Create(ctx, tt.config, force) // Assert on expected values if tt.expectErr { @@ -208,7 +211,7 @@ func TestSystemd_Create(t *testing.T) { }), ).Return(nil).Once() - err := systemd.Create(tt.config, false) + err := systemd.Create(ctx, tt.config, false) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor, testFileSystem) @@ -243,6 +246,8 @@ func TestSystemd_Start(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} + ctx := context.Background() + systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds @@ -252,7 +257,7 @@ func TestSystemd_Start(t *testing.T) { []string{"--user", "start", basicDaemonConfig.Label}, ).Return(0, nil).Once() - err := systemd.Start(basicDaemonConfig.Label) + err := systemd.Start(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -267,7 +272,7 @@ func TestSystemd_Start(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(1, nil).Once() - err := systemd.Start(basicDaemonConfig.Label) + err := systemd.Start(ctx, basicDaemonConfig.Label) assert.NotNil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -285,6 +290,8 @@ func TestSystemd_Stop(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} + ctx := context.Background() + systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds @@ -294,7 +301,7 @@ func TestSystemd_Stop(t *testing.T) { []string{"--user", "stop", basicDaemonConfig.Label}, ).Return(0, nil).Once() - err := systemd.Stop(basicDaemonConfig.Label) + err := systemd.Stop(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -309,7 +316,7 @@ func TestSystemd_Stop(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(1, nil).Once() - err := systemd.Stop(basicDaemonConfig.Label) + err := systemd.Stop(ctx, basicDaemonConfig.Label) assert.NotNil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -324,7 +331,7 @@ func TestSystemd_Stop(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(daemon.SystemdUnitNotInstalledErrorCode, nil).Once() - err := systemd.Stop(basicDaemonConfig.Label) + err := systemd.Stop(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -379,6 +386,8 @@ func TestSystemd_Remove(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} testFileSystem := &MockFileSystem{} + ctx := context.Background() + systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdRemoveTests { @@ -400,7 +409,7 @@ func TestSystemd_Remove(t *testing.T) { } // Call function - err := systemd.Remove(tt.label) + err := systemd.Remove(ctx, tt.label) mock.AssertExpectationsForObjects(t, testCommandExecutor) if tt.expectErr { assert.NotNil(t, err) From 13e2319428fe96459796c57509096b0196916a52 Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Thu, 23 Feb 2023 14:53:54 -0800 Subject: [PATCH 03/10] git-bundle-web-server: extract server setup into 'bundleWebServer' struct Move the server setup & execution logic out of 'main.go' and into a new 'bundleWebServer' struct in 'bundle.go'. This brings the web server more in line with 'git-bundle-server', which similarly uses its command structs to contain operational logic. This new design makes passing around the 'context' structure a bit more straightforward, and (eventually) helps avoid needing to pass a logger instance to each function called by 'git-bundle-web-server'. Signed-off-by: Victoria Dye --- cmd/git-bundle-web-server/bundle-server.go | 141 +++++++++++++++++++++ cmd/git-bundle-web-server/main.go | 117 +---------------- 2 files changed, 145 insertions(+), 113 deletions(-) create mode 100644 cmd/git-bundle-web-server/bundle-server.go diff --git a/cmd/git-bundle-web-server/bundle-server.go b/cmd/git-bundle-web-server/bundle-server.go new file mode 100644 index 0000000..e25cce2 --- /dev/null +++ b/cmd/git-bundle-web-server/bundle-server.go @@ -0,0 +1,141 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "strings" + "sync" + "syscall" + + "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/core" +) + +type bundleWebServer struct { + server *http.Server + serverWaitGroup *sync.WaitGroup + listenAndServeFunc func() error +} + +func NewBundleWebServer(port string, certFile string, keyFile string) *bundleWebServer { + bundleServer := &bundleWebServer{ + serverWaitGroup: &sync.WaitGroup{}, + } + + // Configure the http.Server + mux := http.NewServeMux() + mux.HandleFunc("/", bundleServer.serve) + bundleServer.server = &http.Server{ + Handler: mux, + Addr: ":" + port, + } + + if certFile != "" { + bundleServer.listenAndServeFunc = func() error { return bundleServer.server.ListenAndServeTLS(certFile, keyFile) } + } else { + bundleServer.listenAndServeFunc = func() error { return bundleServer.server.ListenAndServe() } + } + + return bundleServer +} + +func (b *bundleWebServer) parseRoute(ctx context.Context, path string) (string, string, string, error) { + elements := strings.FieldsFunc(path, func(char rune) bool { return char == '/' }) + switch len(elements) { + case 0: + return "", "", "", fmt.Errorf("empty route") + case 1: + return "", "", "", fmt.Errorf("route has owner, but no repo") + case 2: + return elements[0], elements[1], "", nil + case 3: + return elements[0], elements[1], elements[2], nil + default: + return "", "", "", fmt.Errorf("path has depth exceeding three") + } +} + +func (b *bundleWebServer) serve(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + user, err := common.NewUserProvider().CurrentUser() + if err != nil { + return + } + fs := common.NewFileSystem() + path := r.URL.Path + + owner, repo, file, err := b.parseRoute(ctx, path) + if err != nil { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Failed to parse route: %s\n", err) + return + } + + route := owner + "/" + repo + + repos, err := core.GetRepositories(user, fs) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Printf("Failed to load routes\n") + return + } + + repository, contains := repos[route] + if !contains { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Failed to get route out of repos\n") + return + } + + if file == "" { + file = "bundle-list" + } + + fileToServe := repository.WebDir + "/" + file + data, err := os.ReadFile(fileToServe) + if err != nil { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Failed to read file\n") + return + } + + fmt.Printf("Successfully serving content for %s/%s\n", route, file) + w.Write(data) +} + +func (b *bundleWebServer) StartServerAsync(ctx context.Context) { + // Add to wait group + b.serverWaitGroup.Add(1) + + go func(ctx context.Context) { + defer b.serverWaitGroup.Done() + + // Return error unless it indicates graceful shutdown + err := b.listenAndServeFunc() + if err != nil && err != http.ErrServerClosed { + log.Fatal(err) + } + }(ctx) + + fmt.Println("Server is running at address " + b.server.Addr) +} + +func (b *bundleWebServer) HandleSignalsAsync(ctx context.Context) { + // Intercept interrupt signals + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + go func(ctx context.Context) { + <-c + fmt.Println("Starting graceful server shutdown...") + b.server.Shutdown(ctx) + }(ctx) +} + +func (b *bundleWebServer) Wait() { + b.serverWaitGroup.Wait() +} diff --git a/cmd/git-bundle-web-server/main.go b/cmd/git-bundle-web-server/main.go index 2f5589f..9d8887b 100644 --- a/cmd/git-bundle-web-server/main.go +++ b/cmd/git-bundle-web-server/main.go @@ -4,109 +4,12 @@ import ( "context" "flag" "fmt" - "log" - "net/http" "os" - "os/signal" - "strings" - "sync" - "syscall" "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" - "github.com/github/git-bundle-server/internal/common" - "github.com/github/git-bundle-server/internal/core" ) -func parseRoute(path string) (string, string, string, error) { - elements := strings.FieldsFunc(path, func(char rune) bool { return char == '/' }) - switch len(elements) { - case 0: - return "", "", "", fmt.Errorf("empty route") - case 1: - return "", "", "", fmt.Errorf("route has owner, but no repo") - case 2: - return elements[0], elements[1], "", nil - case 3: - return elements[0], elements[1], elements[2], nil - default: - return "", "", "", fmt.Errorf("path has depth exceeding three") - } -} - -func serve(w http.ResponseWriter, r *http.Request) { - user, err := common.NewUserProvider().CurrentUser() - if err != nil { - return - } - fs := common.NewFileSystem() - path := r.URL.Path - - owner, repo, file, err := parseRoute(path) - if err != nil { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Failed to parse route: %s\n", err) - return - } - - route := owner + "/" + repo - - repos, err := core.GetRepositories(user, fs) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - fmt.Printf("Failed to load routes\n") - return - } - - repository, contains := repos[route] - if !contains { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Failed to get route out of repos\n") - return - } - - if file == "" { - file = "bundle-list" - } - - fileToServe := repository.WebDir + "/" + file - data, err := os.ReadFile(fileToServe) - if err != nil { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Failed to read file\n") - return - } - - fmt.Printf("Successfully serving content for %s/%s\n", route, file) - w.Write(data) -} - -func startServer(server *http.Server, - cert string, key string, - serverWaitGroup *sync.WaitGroup, -) { - // Add to wait group - serverWaitGroup.Add(1) - - go func() { - defer serverWaitGroup.Done() - - // Return error unless it indicates graceful shutdown - var err error - if cert != "" { - err = server.ListenAndServeTLS(cert, key) - } else { - err = server.ListenAndServe() - } - - if err != nil && err != http.ErrServerClosed { - log.Fatal(err) - } - }() - - fmt.Println("Server is running at address " + server.Addr) -} - func main() { ctx := context.Background() @@ -124,28 +27,16 @@ func main() { key := utils.GetFlagValue[string](parser, "key") // Configure the server - mux := http.NewServeMux() - mux.HandleFunc("/", serve) - server := &http.Server{ - Handler: mux, - Addr: ":" + port, - } - serverWaitGroup := &sync.WaitGroup{} + bundleServer := NewBundleWebServer(port, cert, key) // Start the server asynchronously - startServer(server, cert, key, serverWaitGroup) + bundleServer.StartServerAsync(ctx) // Intercept interrupt signals - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-c - fmt.Println("Starting graceful server shutdown...") - server.Shutdown(ctx) - }() + bundleServer.HandleSignalsAsync(ctx) // Wait for server to shut down - serverWaitGroup.Wait() + bundleServer.Wait() fmt.Println("Shutdown complete") } From 7f309c08d5befda65e339ed0cc6d1ff01ad666c7 Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Wed, 22 Feb 2023 11:17:12 -0800 Subject: [PATCH 04/10] git-bundle-server: add constructors for all commands In later patches, a logger struct will need to be a member of each command struct to facilitate structured logging & tracing. Like the 'web-server' command in ca23190 (bundle-server: add 'web-server' subcommand, 2023-01-05), add explicit constructors (i.e., 'New()') to the remaining commands. Signed-off-by: Victoria Dye --- cmd/git-bundle-server/delete.go | 12 ++++++++---- cmd/git-bundle-server/init.go | 12 ++++++++---- cmd/git-bundle-server/main.go | 12 ++++++------ cmd/git-bundle-server/start.go | 12 ++++++++---- cmd/git-bundle-server/stop.go | 12 ++++++++---- cmd/git-bundle-server/update-all.go | 12 ++++++++---- cmd/git-bundle-server/update.go | 12 ++++++++---- cmd/git-bundle-server/web-server.go | 20 ++++++++++---------- 8 files changed, 64 insertions(+), 40 deletions(-) diff --git a/cmd/git-bundle-server/delete.go b/cmd/git-bundle-server/delete.go index ca48ec0..4d4c279 100644 --- a/cmd/git-bundle-server/delete.go +++ b/cmd/git-bundle-server/delete.go @@ -8,19 +8,23 @@ import ( "github.com/github/git-bundle-server/internal/core" ) -type Delete struct{} +type deleteCmd struct{} -func (Delete) Name() string { +func NewDeleteCommand() argparse.Subcommand { + return &deleteCmd{} +} + +func (deleteCmd) Name() string { return "delete" } -func (Delete) Description() string { +func (deleteCmd) Description() string { return ` Remove the configuration for the given '' and delete its repository data.` } -func (Delete) Run(ctx context.Context, args []string) error { +func (deleteCmd) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server delete ") route := parser.PositionalString("route", "the route to delete") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/init.go b/cmd/git-bundle-server/init.go index 7ee9e56..78f77b2 100644 --- a/cmd/git-bundle-server/init.go +++ b/cmd/git-bundle-server/init.go @@ -10,19 +10,23 @@ import ( "github.com/github/git-bundle-server/internal/git" ) -type Init struct{} +type initCmd struct{} -func (Init) Name() string { +func NewInitCommand() argparse.Subcommand { + return &initCmd{} +} + +func (initCmd) Name() string { return "init" } -func (Init) Description() string { +func (initCmd) Description() string { return ` Initialize a repository by cloning a bare repo from '', whose bundles should be hosted at ''.` } -func (Init) Run(ctx context.Context, args []string) error { +func (initCmd) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server init ") url := parser.PositionalString("url", "the URL of a repository to clone") // TODO: allow parsing out of diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index 1a0f64f..8ca9021 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -10,12 +10,12 @@ import ( func all() []argparse.Subcommand { return []argparse.Subcommand{ - Delete{}, - Init{}, - Start{}, - Stop{}, - Update{}, - UpdateAll{}, + NewDeleteCommand(), + NewInitCommand(), + NewStartCommand(), + NewStopCommand(), + NewUpdateCommand(), + NewUpdateAllCommand(), NewWebServerCommand(), } } diff --git a/cmd/git-bundle-server/start.go b/cmd/git-bundle-server/start.go index 6123bc9..b0f0a65 100644 --- a/cmd/git-bundle-server/start.go +++ b/cmd/git-bundle-server/start.go @@ -9,19 +9,23 @@ import ( "github.com/github/git-bundle-server/internal/core" ) -type Start struct{} +type startCmd struct{} -func (Start) Name() string { +func NewStartCommand() argparse.Subcommand { + return &startCmd{} +} + +func (startCmd) Name() string { return "start" } -func (Start) Description() string { +func (startCmd) Description() string { return ` Start computing bundles and serving content for the repository at the specified ''.` } -func (Start) Run(ctx context.Context, args []string) error { +func (startCmd) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server start ") route := parser.PositionalString("route", "the route for which bundles should be generated") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/stop.go b/cmd/git-bundle-server/stop.go index 920ae54..fc95e04 100644 --- a/cmd/git-bundle-server/stop.go +++ b/cmd/git-bundle-server/stop.go @@ -7,19 +7,23 @@ import ( "github.com/github/git-bundle-server/internal/core" ) -type Stop struct{} +type stopCmd struct{} -func (Stop) Name() string { +func NewStopCommand() argparse.Subcommand { + return &stopCmd{} +} + +func (stopCmd) Name() string { return "stop" } -func (Stop) Description() string { +func (stopCmd) Description() string { return ` Stop computing bundles or serving content for the repository at the specified ''.` } -func (Stop) Run(ctx context.Context, args []string) error { +func (stopCmd) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server stop ") route := parser.PositionalString("route", "the route for which bundles should stop being generated") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/update-all.go b/cmd/git-bundle-server/update-all.go index 79c9300..b714544 100644 --- a/cmd/git-bundle-server/update-all.go +++ b/cmd/git-bundle-server/update-all.go @@ -11,18 +11,22 @@ import ( "github.com/github/git-bundle-server/internal/core" ) -type UpdateAll struct{} +type updateAllCmd struct{} -func (UpdateAll) Name() string { +func NewUpdateAllCommand() argparse.Subcommand { + return &updateAllCmd{} +} + +func (updateAllCmd) Name() string { return "update-all" } -func (UpdateAll) Description() string { +func (updateAllCmd) Description() string { return ` For every configured route, run 'git-bundle-server update '.` } -func (UpdateAll) Run(ctx context.Context, args []string) error { +func (updateAllCmd) Run(ctx context.Context, args []string) error { user, err := common.NewUserProvider().CurrentUser() if err != nil { return err diff --git a/cmd/git-bundle-server/update.go b/cmd/git-bundle-server/update.go index bc52dcb..bfabca3 100644 --- a/cmd/git-bundle-server/update.go +++ b/cmd/git-bundle-server/update.go @@ -9,20 +9,24 @@ import ( "github.com/github/git-bundle-server/internal/core" ) -type Update struct{} +type updateCmd struct{} -func (Update) Name() string { +func NewUpdateCommand() argparse.Subcommand { + return &updateAllCmd{} +} + +func (updateCmd) Name() string { return "update" } -func (Update) Description() string { +func (updateCmd) Description() string { return ` For the repository in the current directory (or the one specified by ''), fetch the latest content from the remote, create a new set of bundles, and update the bundle list.` } -func (Update) Run(ctx context.Context, args []string) error { +func (updateCmd) Run(ctx context.Context, args []string) error { parser := argparse.NewArgParser("git-bundle-server update ") route := parser.PositionalString("route", "the route to update") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/web-server.go b/cmd/git-bundle-server/web-server.go index 6130673..c6fbb49 100644 --- a/cmd/git-bundle-server/web-server.go +++ b/cmd/git-bundle-server/web-server.go @@ -15,30 +15,30 @@ import ( "github.com/github/git-bundle-server/internal/daemon" ) -type webServer struct { +type webServerCmd struct { user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } -func NewWebServerCommand() *webServer { - // Create dependencies - return &webServer{ +func NewWebServerCommand() argparse.Subcommand { + // Create subcommand-specific dependencies + return &webServerCmd{ user: common.NewUserProvider(), cmdExec: common.NewCommandExecutor(), fileSystem: common.NewFileSystem(), } } -func (webServer) Name() string { +func (webServerCmd) Name() string { return "web-server" } -func (webServer) Description() string { +func (webServerCmd) Description() string { return `Manage the web server hosting bundle content` } -func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) { +func (w *webServerCmd) getDaemonConfig() (*daemon.DaemonConfig, error) { // Find git-bundle-web-server // First, search for it on the path programPath, err := exec.LookPath("git-bundle-web-server") @@ -77,7 +77,7 @@ func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) { }, nil } -func (w *webServer) startServer(ctx context.Context, args []string) error { +func (w *webServerCmd) startServer(ctx context.Context, args []string) error { // Parse subcommand arguments parser := argparse.NewArgParser("git-bundle-server web-server start [-f|--force]") @@ -144,7 +144,7 @@ func (w *webServer) startServer(ctx context.Context, args []string) error { return nil } -func (w *webServer) stopServer(ctx context.Context, args []string) error { +func (w *webServerCmd) stopServer(ctx context.Context, args []string) error { // Parse subcommand arguments parser := argparse.NewArgParser("git-bundle-server web-server stop [--remove]") remove := parser.Bool("remove", false, "Remove the web server daemon configuration from the system after stopping") @@ -175,7 +175,7 @@ func (w *webServer) stopServer(ctx context.Context, args []string) error { return nil } -func (w *webServer) Run(ctx context.Context, args []string) error { +func (w *webServerCmd) Run(ctx context.Context, args []string) error { // Parse command arguments parser := argparse.NewArgParser("git-bundle-server web-server (start|stop) ") parser.Subcommand(argparse.NewSubcommand("start", "Start the web server", w.startServer)) From 4aea0a31838e47fd7539a1be0c18b0bea7139184 Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Thu, 23 Feb 2023 13:15:23 -0800 Subject: [PATCH 05/10] log: create initial trace2 logger Create a general 'TraceLogger' interface and (for now) a sole implementation in the 'Trace2' struct. There are two main components of this commit: the Trace2 logger itself, and how it's used in the CLIs built in this repository. The Trace2 logger is built around the 'zap' structured logger. As of right now, it does not log anything other than 'start', 'exit', and 'atexit' (and only through the 'WithLogger()' wrapper described later). The logs include: * 'event': the event type; it takes over the 'message' field of the 'zap' logging functions, mostly for the sake of simplicity (those functions require a message, and every trace2 log has an 'event'). * 'sid': a UUID generated by the 'uuid.NewUUID()' function, propagated through the call stack via the 'context.Context'. * 'time': the UTC timestamp of the log generation; this is automatically provided by a 'zap' logger, but requires a custom formatter to provide the timestamp in UTC. * 't_abs': the time (in seconds) since the start of the program (technically, when the 'log' package is loaded and the 'globalStartTime' is initialized). * 'file' and 'line': the filename and line number where the relevant trace2 log function was called. * 'thread': hardcoded to 'main', but may be updated in the future to give unique names to goroutines. From a user's perspective, the logger is configured by the 'GIT_TRACE2_EVENT' environment variable. Its behavior matches what is described in Git [1], but it is currently missing support for open file handles 2-9 and Unix Domain Sockets. To use the trace2 logger, create the 'main()' function wrapper 'WithLogger()'. This function creates an instance of a 'TraceLogger', logs the program start, defers a "cleanup" logging function to write out the 'exit'/'atexit' logs, then runs the wrapped function. [1] https://git-scm.com/docs/git-config#Documentation/git-config.txt-trace2eventTarget Signed-off-by: Victoria Dye --- cmd/git-bundle-server/main.go | 27 ++--- cmd/git-bundle-web-server/main.go | 44 ++++---- go.mod | 8 +- go.sum | 11 ++ internal/log/logger.go | 42 ++++++++ internal/log/trace2.go | 167 ++++++++++++++++++++++++++++++ 6 files changed, 264 insertions(+), 35 deletions(-) create mode 100644 internal/log/logger.go create mode 100644 internal/log/trace2.go diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index 8ca9021..56cfea3 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -6,6 +6,7 @@ import ( "os" "github.com/github/git-bundle-server/internal/argparse" + tracelog "github.com/github/git-bundle-server/internal/log" ) func all() []argparse.Subcommand { @@ -21,19 +22,19 @@ func all() []argparse.Subcommand { } func main() { - ctx := context.Background() + tracelog.WithTraceLogger(context.Background(), func(ctx context.Context, logger tracelog.TraceLogger) { + cmds := all() - cmds := all() + parser := argparse.NewArgParser("git-bundle-server []") + parser.SetIsTopLevel(true) + for _, cmd := range cmds { + parser.Subcommand(cmd) + } + parser.Parse(ctx, os.Args[1:]) - parser := argparse.NewArgParser("git-bundle-server []") - parser.SetIsTopLevel(true) - for _, cmd := range cmds { - parser.Subcommand(cmd) - } - parser.Parse(ctx, os.Args[1:]) - - err := parser.InvokeSubcommand(ctx) - if err != nil { - log.Fatal("Failed with error: ", err) - } + err := parser.InvokeSubcommand(ctx) + if err != nil { + log.Fatalf("Failed with error: %s", err) + } + }) } diff --git a/cmd/git-bundle-web-server/main.go b/cmd/git-bundle-web-server/main.go index 9d8887b..1e37b96 100644 --- a/cmd/git-bundle-web-server/main.go +++ b/cmd/git-bundle-web-server/main.go @@ -8,35 +8,37 @@ import ( "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" + "github.com/github/git-bundle-server/internal/log" ) func main() { - ctx := context.Background() + log.WithTraceLogger(context.Background(), func(ctx context.Context, logger log.TraceLogger) { + parser := argparse.NewArgParser("git-bundle-web-server [--port ] [--cert --key ]") + flags, validate := utils.WebServerFlags(parser) + flags.VisitAll(func(f *flag.Flag) { + parser.Var(f.Value, f.Name, f.Usage) + }) - parser := argparse.NewArgParser("git-bundle-web-server [--port ] [--cert --key ]") - flags, validate := utils.WebServerFlags(parser) - flags.VisitAll(func(f *flag.Flag) { - parser.Var(f.Value, f.Name, f.Usage) - }) - parser.Parse(ctx, os.Args[1:]) - validate(ctx) + parser.Parse(ctx, os.Args[1:]) + validate(ctx) - // Get the flag values - port := utils.GetFlagValue[string](parser, "port") - cert := utils.GetFlagValue[string](parser, "cert") - key := utils.GetFlagValue[string](parser, "key") + // Get the flag values + port := utils.GetFlagValue[string](parser, "port") + cert := utils.GetFlagValue[string](parser, "cert") + key := utils.GetFlagValue[string](parser, "key") - // Configure the server - bundleServer := NewBundleWebServer(port, cert, key) + // Configure the server + bundleServer := NewBundleWebServer(port, cert, key) - // Start the server asynchronously - bundleServer.StartServerAsync(ctx) + // Start the server asynchronously + bundleServer.StartServerAsync(ctx) - // Intercept interrupt signals - bundleServer.HandleSignalsAsync(ctx) + // Intercept interrupt signals + bundleServer.HandleSignalsAsync(ctx) - // Wait for server to shut down - bundleServer.Wait() + // Wait for server to shut down + bundleServer.Wait() - fmt.Println("Shutdown complete") + fmt.Println("Shutdown complete") + }) } diff --git a/go.mod b/go.mod index e0bc362..82bb106 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,17 @@ module github.com/github/git-bundle-server go 1.19 -require github.com/stretchr/testify v1.8.1 +require ( + github.com/google/uuid v1.3.0 + github.com/stretchr/testify v1.8.1 + go.uber.org/zap v1.24.0 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect + go.uber.org/atomic v1.10.0 // indirect + go.uber.org/multierr v1.9.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 231c467..89383a5 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -11,6 +15,13 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/log/logger.go b/internal/log/logger.go new file mode 100644 index 0000000..cf062ea --- /dev/null +++ b/internal/log/logger.go @@ -0,0 +1,42 @@ +package log + +import ( + "context" + "fmt" + "os" + "runtime/debug" +) + +type TraceLogger interface{} + +type traceLoggerInternal interface { + // Internal setup/teardown functions + logStart(ctx context.Context) context.Context + logExit(ctx context.Context, exitCode int) + + TraceLogger +} + +func WithTraceLogger( + ctx context.Context, + mainFunc func(context.Context, TraceLogger), +) { + logger := NewTrace2() + + // Set up the program-level context + ctx = logger.logStart(ctx) + defer func() { + if panicInfo := recover(); panicInfo != nil { + // Panicking - log, print panic info, then exit + logger.logExit(ctx, 1) + os.Stderr.WriteString(fmt.Sprintf("panic: %s\n\n", panicInfo)) + debug.PrintStack() + os.Exit(1) + } else { + // Just log the exit (but don't os.Exit()) so we can exit normally + logger.logExit(ctx, 0) + } + }() + + mainFunc(ctx, logger) +} diff --git a/internal/log/trace2.go b/internal/log/trace2.go new file mode 100644 index 0000000..ddad24f --- /dev/null +++ b/internal/log/trace2.go @@ -0,0 +1,167 @@ +package log + +import ( + "context" + "fmt" + "os" + "path" + "path/filepath" + "runtime" + "strconv" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Trace2 environment variables +const ( + // TODO: handle GIT_TRACE2 by adding a separate output config (see zapcore + // "AdvancedConfiguration" example: + // https://pkg.go.dev/go.uber.org/zap#example-package-AdvancedConfiguration) + trace2Event string = "GIT_TRACE2_EVENT" +) + +// Global start time +var globalStart = time.Now().UTC() + +const trace2TimeFormat string = "2006-01-02T15:04:05.999999Z" + +type ctxKey int + +const ( + sidId ctxKey = iota +) + +type Trace2 struct { + logger *zap.Logger +} + +func getTrace2OutputPaths(envKey string) []string { + tr2Output := os.Getenv(envKey) + + // Configure the output + if tr2, err := strconv.Atoi(tr2Output); err == nil { + // Handle numeric values + if tr2 == 1 { + return []string{"stderr"} + } + // TODO: handle file handles 2-9 and unix sockets + } else if tr2Output != "" { + // Assume we received a path + fileInfo, err := os.Stat(tr2Output) + if err == nil && fileInfo.IsDir() { + // If the path is an existing directory, generate a filename + return []string{ + filepath.Join(tr2Output, fmt.Sprintf("trace2_%s.txt", globalStart.Format(trace2TimeFormat))), + } + } else { + // Create leading directories + parentDir := path.Dir(tr2Output) + os.MkdirAll(parentDir, 0o755) + return []string{tr2Output} + } + } + + return []string{} +} + +func createTrace2ZapLogger() *zap.Logger { + loggerConfig := zap.NewProductionConfig() + + // Configure the output for GIT_TRACE2_EVENT + loggerConfig.OutputPaths = getTrace2OutputPaths(trace2Event) + loggerConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + + // Encode UTC time + loggerConfig.EncoderConfig.TimeKey = "time" + loggerConfig.EncoderConfig.EncodeTime = zapcore.TimeEncoder( + func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.Format(trace2TimeFormat)) + }, + ) + + // Re-purpose the "message" to represent the (always-present) "event" key + loggerConfig.EncoderConfig.MessageKey = "event" + + // Don't print the log level + loggerConfig.EncoderConfig.LevelKey = "" + + // Disable caller info, we'll customize those fields manually + logger, _ := loggerConfig.Build(zap.WithCaller(false)) + return logger +} + +func NewTrace2() traceLoggerInternal { + return &Trace2{ + logger: createTrace2ZapLogger(), + } +} + +type fieldList []zap.Field + +func (l fieldList) withTime() fieldList { + return append(l, zap.Float64("t_abs", time.Since(globalStart).Seconds())) +} + +func (l fieldList) with(f ...zap.Field) fieldList { + return append(l, f...) +} + +func (t *Trace2) sharedFields(ctx context.Context) (context.Context, fieldList) { + fields := fieldList{} + + // Get the session ID + var sid uuid.UUID + haveSid := false + sidAny := ctx.Value(sidId) + if sidAny != nil { + sid, haveSid = sidAny.(uuid.UUID) + } + if !haveSid { + sid = uuid.New() + ctx = context.WithValue(ctx, sidId, sid) + } + fields = append(fields, zap.String("sid", sid.String())) + + // Hardcode the thread to "main" because Go doesn't like to share its + // internal info about threading. + fields = append(fields, zap.String("thread", "main")) + + // Get the caller of the function in trace2.go + // Skip up two levels: + // 0: this function + // 1: the caller of this function (StartTrace, LogEvent, etc.) + // 2: the function calling this trace2 library + _, fileName, lineNum, ok := runtime.Caller(2) + if ok { + fields = append(fields, + zap.String("file", filepath.Base(fileName)), + zap.Int("line", lineNum), + ) + } + + return ctx, fields +} + +func (t *Trace2) logStart(ctx context.Context) context.Context { + ctx, sharedFields := t.sharedFields(ctx) + + t.logger.Info("start", sharedFields.withTime().with( + zap.Strings("argv", os.Args), + )...) + + return ctx +} + +func (t *Trace2) logExit(ctx context.Context, exitCode int) { + _, sharedFields := t.sharedFields(ctx) + fields := sharedFields.with( + zap.Int("code", exitCode), + ) + t.logger.Info("exit", fields.withTime()...) + t.logger.Info("atexit", fields.withTime()...) + + t.logger.Sync() +} From ce15989ba1511848310850695ddff41259bd38d5 Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Thu, 23 Feb 2023 14:27:34 -0800 Subject: [PATCH 06/10] log: propagate to deeper structures To increase the amount of the code covered by logging, pass the 'main()' functions' 'logger' instances to structures it creates to execute the user-specified operation (i.e., the arg parser, 'git-bundle-server' commands, 'bundleWebServer', and daemon providers). For now, do nothing with the 'logger' instance in those structs. Signed-off-by: Victoria Dye --- cmd/git-bundle-server/delete.go | 15 ++++++++++----- cmd/git-bundle-server/init.go | 15 ++++++++++----- cmd/git-bundle-server/main.go | 20 ++++++++++---------- cmd/git-bundle-server/start.go | 15 ++++++++++----- cmd/git-bundle-server/stop.go | 15 ++++++++++----- cmd/git-bundle-server/update-all.go | 15 ++++++++++----- cmd/git-bundle-server/update.go | 15 ++++++++++----- cmd/git-bundle-server/web-server.go | 15 +++++++++------ cmd/git-bundle-web-server/bundle-server.go | 7 ++++++- cmd/git-bundle-web-server/main.go | 4 ++-- internal/argparse/argparse.go | 6 +++++- internal/daemon/daemon.go | 6 ++++-- internal/daemon/launchd.go | 4 ++++ internal/daemon/launchd_test.go | 8 ++++---- internal/daemon/systemd.go | 4 ++++ internal/daemon/systemd_test.go | 8 ++++---- 16 files changed, 112 insertions(+), 60 deletions(-) diff --git a/cmd/git-bundle-server/delete.go b/cmd/git-bundle-server/delete.go index 4d4c279..270b11d 100644 --- a/cmd/git-bundle-server/delete.go +++ b/cmd/git-bundle-server/delete.go @@ -6,12 +6,17 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type deleteCmd struct{} +type deleteCmd struct { + logger log.TraceLogger +} -func NewDeleteCommand() argparse.Subcommand { - return &deleteCmd{} +func NewDeleteCommand(logger log.TraceLogger) argparse.Subcommand { + return &deleteCmd{ + logger: logger, + } } func (deleteCmd) Name() string { @@ -24,8 +29,8 @@ Remove the configuration for the given '' and delete its repository data.` } -func (deleteCmd) Run(ctx context.Context, args []string) error { - parser := argparse.NewArgParser("git-bundle-server delete ") +func (d *deleteCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(d.logger, "git-bundle-server delete ") route := parser.PositionalString("route", "the route to delete") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/init.go b/cmd/git-bundle-server/init.go index 78f77b2..bd3c773 100644 --- a/cmd/git-bundle-server/init.go +++ b/cmd/git-bundle-server/init.go @@ -8,12 +8,17 @@ import ( "github.com/github/git-bundle-server/internal/bundles" "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/git" + "github.com/github/git-bundle-server/internal/log" ) -type initCmd struct{} +type initCmd struct { + logger log.TraceLogger +} -func NewInitCommand() argparse.Subcommand { - return &initCmd{} +func NewInitCommand(logger log.TraceLogger) argparse.Subcommand { + return &initCmd{ + logger: logger, + } } func (initCmd) Name() string { @@ -26,8 +31,8 @@ Initialize a repository by cloning a bare repo from '', whose bundles should be hosted at ''.` } -func (initCmd) Run(ctx context.Context, args []string) error { - parser := argparse.NewArgParser("git-bundle-server init ") +func (i *initCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(i.logger, "git-bundle-server init ") url := parser.PositionalString("url", "the URL of a repository to clone") // TODO: allow parsing out of route := parser.PositionalString("route", "the route to host the specified repo") diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index 56cfea3..0654a9e 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -9,23 +9,23 @@ import ( tracelog "github.com/github/git-bundle-server/internal/log" ) -func all() []argparse.Subcommand { +func all(logger tracelog.TraceLogger) []argparse.Subcommand { return []argparse.Subcommand{ - NewDeleteCommand(), - NewInitCommand(), - NewStartCommand(), - NewStopCommand(), - NewUpdateCommand(), - NewUpdateAllCommand(), - NewWebServerCommand(), + NewDeleteCommand(logger), + NewInitCommand(logger), + NewStartCommand(logger), + NewStopCommand(logger), + NewUpdateCommand(logger), + NewUpdateAllCommand(logger), + NewWebServerCommand(logger), } } func main() { tracelog.WithTraceLogger(context.Background(), func(ctx context.Context, logger tracelog.TraceLogger) { - cmds := all() + cmds := all(logger) - parser := argparse.NewArgParser("git-bundle-server []") + parser := argparse.NewArgParser(logger, "git-bundle-server []") parser.SetIsTopLevel(true) for _, cmd := range cmds { parser.Subcommand(cmd) diff --git a/cmd/git-bundle-server/start.go b/cmd/git-bundle-server/start.go index b0f0a65..b26f244 100644 --- a/cmd/git-bundle-server/start.go +++ b/cmd/git-bundle-server/start.go @@ -7,12 +7,17 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type startCmd struct{} +type startCmd struct { + logger log.TraceLogger +} -func NewStartCommand() argparse.Subcommand { - return &startCmd{} +func NewStartCommand(logger log.TraceLogger) argparse.Subcommand { + return &startCmd{ + logger: logger, + } } func (startCmd) Name() string { @@ -25,8 +30,8 @@ Start computing bundles and serving content for the repository at the specified ''.` } -func (startCmd) Run(ctx context.Context, args []string) error { - parser := argparse.NewArgParser("git-bundle-server start ") +func (s *startCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(s.logger, "git-bundle-server start ") route := parser.PositionalString("route", "the route for which bundles should be generated") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/stop.go b/cmd/git-bundle-server/stop.go index fc95e04..c051dcc 100644 --- a/cmd/git-bundle-server/stop.go +++ b/cmd/git-bundle-server/stop.go @@ -5,12 +5,17 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type stopCmd struct{} +type stopCmd struct { + logger log.TraceLogger +} -func NewStopCommand() argparse.Subcommand { - return &stopCmd{} +func NewStopCommand(logger log.TraceLogger) argparse.Subcommand { + return &stopCmd{ + logger: logger, + } } func (stopCmd) Name() string { @@ -23,8 +28,8 @@ Stop computing bundles or serving content for the repository at the specified ''.` } -func (stopCmd) Run(ctx context.Context, args []string) error { - parser := argparse.NewArgParser("git-bundle-server stop ") +func (s *stopCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(s.logger, "git-bundle-server stop ") route := parser.PositionalString("route", "the route for which bundles should stop being generated") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/update-all.go b/cmd/git-bundle-server/update-all.go index b714544..1322a01 100644 --- a/cmd/git-bundle-server/update-all.go +++ b/cmd/git-bundle-server/update-all.go @@ -9,12 +9,17 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type updateAllCmd struct{} +type updateAllCmd struct { + logger log.TraceLogger +} -func NewUpdateAllCommand() argparse.Subcommand { - return &updateAllCmd{} +func NewUpdateAllCommand(logger log.TraceLogger) argparse.Subcommand { + return &updateAllCmd{ + logger: logger, + } } func (updateAllCmd) Name() string { @@ -26,14 +31,14 @@ func (updateAllCmd) Description() string { For every configured route, run 'git-bundle-server update '.` } -func (updateAllCmd) Run(ctx context.Context, args []string) error { +func (u *updateAllCmd) Run(ctx context.Context, args []string) error { user, err := common.NewUserProvider().CurrentUser() if err != nil { return err } fs := common.NewFileSystem() - parser := argparse.NewArgParser("git-bundle-server update-all") + parser := argparse.NewArgParser(u.logger, "git-bundle-server update-all") parser.Parse(ctx, args) exe, err := os.Executable() diff --git a/cmd/git-bundle-server/update.go b/cmd/git-bundle-server/update.go index bfabca3..acc9709 100644 --- a/cmd/git-bundle-server/update.go +++ b/cmd/git-bundle-server/update.go @@ -7,12 +7,17 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/bundles" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type updateCmd struct{} +type updateCmd struct { + logger log.TraceLogger +} -func NewUpdateCommand() argparse.Subcommand { - return &updateAllCmd{} +func NewUpdateCommand(logger log.TraceLogger) argparse.Subcommand { + return &updateCmd{ + logger: logger, + } } func (updateCmd) Name() string { @@ -26,8 +31,8 @@ For the repository in the current directory (or the one specified by bundles, and update the bundle list.` } -func (updateCmd) Run(ctx context.Context, args []string) error { - parser := argparse.NewArgParser("git-bundle-server update ") +func (u *updateCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(u.logger, "git-bundle-server update ") route := parser.PositionalString("route", "the route to update") parser.Parse(ctx, args) diff --git a/cmd/git-bundle-server/web-server.go b/cmd/git-bundle-server/web-server.go index c6fbb49..677782c 100644 --- a/cmd/git-bundle-server/web-server.go +++ b/cmd/git-bundle-server/web-server.go @@ -13,17 +13,20 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/daemon" + "github.com/github/git-bundle-server/internal/log" ) type webServerCmd struct { + logger log.TraceLogger user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } -func NewWebServerCommand() argparse.Subcommand { +func NewWebServerCommand(logger log.TraceLogger) argparse.Subcommand { // Create subcommand-specific dependencies return &webServerCmd{ + logger: logger, user: common.NewUserProvider(), cmdExec: common.NewCommandExecutor(), fileSystem: common.NewFileSystem(), @@ -79,7 +82,7 @@ func (w *webServerCmd) getDaemonConfig() (*daemon.DaemonConfig, error) { func (w *webServerCmd) startServer(ctx context.Context, args []string) error { // Parse subcommand arguments - parser := argparse.NewArgParser("git-bundle-server web-server start [-f|--force]") + parser := argparse.NewArgParser(w.logger, "git-bundle-server web-server start [-f|--force]") // Args for 'git-bundle-server web-server start' force := parser.Bool("force", false, "Force reconfiguration of the web server daemon") @@ -94,7 +97,7 @@ func (w *webServerCmd) startServer(ctx context.Context, args []string) error { parser.Parse(ctx, args) validate(ctx) - d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem) + d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) if err != nil { return err } @@ -146,11 +149,11 @@ func (w *webServerCmd) startServer(ctx context.Context, args []string) error { func (w *webServerCmd) stopServer(ctx context.Context, args []string) error { // Parse subcommand arguments - parser := argparse.NewArgParser("git-bundle-server web-server stop [--remove]") + parser := argparse.NewArgParser(w.logger, "git-bundle-server web-server stop [--remove]") remove := parser.Bool("remove", false, "Remove the web server daemon configuration from the system after stopping") parser.Parse(ctx, args) - d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem) + d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) if err != nil { return err } @@ -177,7 +180,7 @@ func (w *webServerCmd) stopServer(ctx context.Context, args []string) error { func (w *webServerCmd) Run(ctx context.Context, args []string) error { // Parse command arguments - parser := argparse.NewArgParser("git-bundle-server web-server (start|stop) ") + parser := argparse.NewArgParser(w.logger, "git-bundle-server web-server (start|stop) ") parser.Subcommand(argparse.NewSubcommand("start", "Start the web server", w.startServer)) parser.Subcommand(argparse.NewSubcommand("stop", "Stop the web server", w.stopServer)) parser.Parse(ctx, args) diff --git a/cmd/git-bundle-web-server/bundle-server.go b/cmd/git-bundle-web-server/bundle-server.go index e25cce2..db923ae 100644 --- a/cmd/git-bundle-web-server/bundle-server.go +++ b/cmd/git-bundle-web-server/bundle-server.go @@ -13,16 +13,21 @@ import ( "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/core" + tracelog "github.com/github/git-bundle-server/internal/log" ) type bundleWebServer struct { + logger tracelog.TraceLogger server *http.Server serverWaitGroup *sync.WaitGroup listenAndServeFunc func() error } -func NewBundleWebServer(port string, certFile string, keyFile string) *bundleWebServer { +func NewBundleWebServer(logger tracelog.TraceLogger, + port string, certFile string, keyFile string, +) *bundleWebServer { bundleServer := &bundleWebServer{ + logger: logger, serverWaitGroup: &sync.WaitGroup{}, } diff --git a/cmd/git-bundle-web-server/main.go b/cmd/git-bundle-web-server/main.go index 1e37b96..be4bd78 100644 --- a/cmd/git-bundle-web-server/main.go +++ b/cmd/git-bundle-web-server/main.go @@ -13,7 +13,7 @@ import ( func main() { log.WithTraceLogger(context.Background(), func(ctx context.Context, logger log.TraceLogger) { - parser := argparse.NewArgParser("git-bundle-web-server [--port ] [--cert --key ]") + parser := argparse.NewArgParser(logger, "git-bundle-web-server [--port ] [--cert --key ]") flags, validate := utils.WebServerFlags(parser) flags.VisitAll(func(f *flag.Flag) { parser.Var(f.Value, f.Name, f.Usage) @@ -28,7 +28,7 @@ func main() { key := utils.GetFlagValue[string](parser, "key") // Configure the server - bundleServer := NewBundleWebServer(port, cert, key) + bundleServer := NewBundleWebServer(logger, port, cert, key) // Start the server asynchronously bundleServer.StartServerAsync(ctx) diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index 1f42d7e..aae5fe5 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -6,6 +6,8 @@ import ( "fmt" "os" "strings" + + "github.com/github/git-bundle-server/internal/log" ) // For consistency with 'flag', use 2 as the usage-related error code @@ -30,10 +32,11 @@ type argParser struct { // Post-parsing selectedSubcommand Subcommand + logger log.TraceLogger flag.FlagSet } -func NewArgParser(usageString string) *argParser { +func NewArgParser(logger log.TraceLogger, usageString string) *argParser { flagSet := flag.NewFlagSet("", flag.ContinueOnError) a := &argParser{ @@ -41,6 +44,7 @@ func NewArgParser(usageString string) *argParser { parsed: false, argOffset: 0, subcommands: make(map[string]Subcommand), + logger: logger, FlagSet: *flagSet, } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 7d5d1ba..c1bd31c 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -6,6 +6,7 @@ import ( "runtime" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" ) type DaemonConfig struct { @@ -26,6 +27,7 @@ type DaemonProvider interface { } func NewDaemonProvider( + l log.TraceLogger, u common.UserProvider, c common.CommandExecutor, fs common.FileSystem, @@ -33,10 +35,10 @@ func NewDaemonProvider( switch thisOs := runtime.GOOS; thisOs { case "linux": // Use systemd/systemctl - return NewSystemdProvider(u, c, fs), nil + return NewSystemdProvider(l, u, c, fs), nil case "darwin": // Use launchd/launchctl - return NewLaunchdProvider(u, c, fs), nil + return NewLaunchdProvider(l, u, c, fs), nil default: return nil, fmt.Errorf("cannot configure daemon handler for OS '%s'", thisOs) } diff --git a/internal/daemon/launchd.go b/internal/daemon/launchd.go index 2d4a12c..df87d6f 100644 --- a/internal/daemon/launchd.go +++ b/internal/daemon/launchd.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" "github.com/github/git-bundle-server/internal/utils" ) @@ -91,17 +92,20 @@ func (c *launchdConfig) toPlist() *plist { } type launchd struct { + logger log.TraceLogger user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } func NewLaunchdProvider( + l log.TraceLogger, u common.UserProvider, c common.CommandExecutor, fs common.FileSystem, ) DaemonProvider { return &launchd{ + logger: l, user: u, cmdExec: c, fileSystem: fs, diff --git a/internal/daemon/launchd_test.go b/internal/daemon/launchd_test.go index 1ef6703..c1bf41d 100644 --- a/internal/daemon/launchd_test.go +++ b/internal/daemon/launchd_test.go @@ -240,7 +240,7 @@ func TestLaunchd_Create(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, testFileSystem) + launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) // Verify launchd commands called for _, tt := range launchdCreateBehaviorTests { @@ -367,7 +367,7 @@ func TestLaunchd_Start(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, nil) + launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, nil) // Test #1: launchctl succeeds t.Run("Calls correct launchctl command", func(t *testing.T) { @@ -448,7 +448,7 @@ func TestLaunchd_Stop(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, nil) + launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, nil) for _, tt := range launchdStopTests { t.Run(tt.title, func(t *testing.T) { @@ -540,7 +540,7 @@ func TestLaunchd_Remove(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, testFileSystem) + launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range launchdRemoveTests { t.Run(tt.title, func(t *testing.T) { diff --git a/internal/daemon/systemd.go b/internal/daemon/systemd.go index 0d22271..e81586a 100644 --- a/internal/daemon/systemd.go +++ b/internal/daemon/systemd.go @@ -9,6 +9,7 @@ import ( "text/template" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" ) const serviceTemplate string = `[Unit] @@ -22,17 +23,20 @@ ExecStart={{sq_escape .Program}}{{range .Arguments}} {{sq_escape .}}{{end}} const SystemdUnitNotInstalledErrorCode int = 5 type systemd struct { + logger log.TraceLogger user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } func NewSystemdProvider( + l log.TraceLogger, u common.UserProvider, c common.CommandExecutor, fs common.FileSystem, ) DaemonProvider { return &systemd{ + logger: l, user: u, cmdExec: c, fileSystem: fs, diff --git a/internal/daemon/systemd_test.go b/internal/daemon/systemd_test.go index 10fe837..f02c945 100644 --- a/internal/daemon/systemd_test.go +++ b/internal/daemon/systemd_test.go @@ -140,7 +140,7 @@ func TestSystemd_Create(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, testFileSystem) + systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdCreateBehaviorTests { forceArg := tt.force.ToBoolList() @@ -248,7 +248,7 @@ func TestSystemd_Start(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, nil) + systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds t.Run("Calls correct systemctl command", func(t *testing.T) { @@ -292,7 +292,7 @@ func TestSystemd_Stop(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, nil) + systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds t.Run("Calls correct systemctl command", func(t *testing.T) { @@ -388,7 +388,7 @@ func TestSystemd_Remove(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, testFileSystem) + systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdRemoveTests { t.Run(tt.title, func(t *testing.T) { From 112099379f21db4334ca73769c96e2672d0b5c5c Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Fri, 24 Feb 2023 08:48:15 -0800 Subject: [PATCH 07/10] testhelpers: create MockTraceLogger Create a 'MockTraceLogger' struct that, like other mocks in 'mock.go', embeds a 'mock.Mock' struct. However, unique to the 'TraceLogger' instance in tests (namely the daemon tests) is that we usually don't want to assert on the logger's behavior; in most cases, we just want it to "do nothing" or pass through. Implement the 'TraceLogger' interface with checks that only invoke 'Mock.Called()' if the method has been mocked at least once (even if it's for args not matching the current ones). If the method is not mocked, fall back on "passthrough" defaults (usually returning the input 'context.Context' and/or the input args). Finally, add an input assertion to 'MockTraceLogger.Error()', to serve as an additional validation in unit tests and mirror the behavior of an actual 'TraceLogger' instance. Signed-off-by: Victoria Dye --- internal/daemon/launchd_test.go | 12 ++-- internal/daemon/systemd_test.go | 12 ++-- internal/testhelpers/mocks.go | 102 ++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/internal/daemon/launchd_test.go b/internal/daemon/launchd_test.go index c1bf41d..b190dfc 100644 --- a/internal/daemon/launchd_test.go +++ b/internal/daemon/launchd_test.go @@ -226,6 +226,7 @@ var launchdCreatePlistTests = []struct { func TestLaunchd_Create(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -240,7 +241,7 @@ func TestLaunchd_Create(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) // Verify launchd commands called for _, tt := range launchdCreateBehaviorTests { @@ -356,6 +357,7 @@ func TestLaunchd_Create(t *testing.T) { func TestLaunchd_Start(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -367,7 +369,7 @@ func TestLaunchd_Start(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, nil) + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, nil) // Test #1: launchctl succeeds t.Run("Calls correct launchctl command", func(t *testing.T) { @@ -437,6 +439,7 @@ var launchdStopTests = []struct { func TestLaunchd_Stop(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -448,7 +451,7 @@ func TestLaunchd_Stop(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, nil) + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, nil) for _, tt := range launchdStopTests { t.Run(tt.title, func(t *testing.T) { @@ -527,6 +530,7 @@ var launchdRemoveTests = []struct { func TestLaunchd_Remove(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -540,7 +544,7 @@ func TestLaunchd_Remove(t *testing.T) { ctx := context.Background() - launchd := daemon.NewLaunchdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range launchdRemoveTests { t.Run(tt.title, func(t *testing.T) { diff --git a/internal/daemon/systemd_test.go b/internal/daemon/systemd_test.go index f02c945..e41923c 100644 --- a/internal/daemon/systemd_test.go +++ b/internal/daemon/systemd_test.go @@ -126,6 +126,7 @@ var systemdCreateServiceUnitTests = []struct { func TestSystemd_Create(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -140,7 +141,7 @@ func TestSystemd_Create(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdCreateBehaviorTests { forceArg := tt.force.ToBoolList() @@ -236,6 +237,7 @@ func TestSystemd_Create(t *testing.T) { func TestSystemd_Start(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -248,7 +250,7 @@ func TestSystemd_Start(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, nil) + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds t.Run("Calls correct systemctl command", func(t *testing.T) { @@ -280,6 +282,7 @@ func TestSystemd_Start(t *testing.T) { func TestSystemd_Stop(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -292,7 +295,7 @@ func TestSystemd_Stop(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, nil) + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds t.Run("Calls correct systemctl command", func(t *testing.T) { @@ -375,6 +378,7 @@ var systemdRemoveTests = []struct { func TestSystemd_Remove(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -388,7 +392,7 @@ func TestSystemd_Remove(t *testing.T) { ctx := context.Background() - systemd := daemon.NewSystemdProvider(nil, testUserProvider, testCommandExecutor, testFileSystem) + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdRemoveTests { t.Run(tt.title, func(t *testing.T) { diff --git a/internal/testhelpers/mocks.go b/internal/testhelpers/mocks.go index 9f8dc90..6481b7e 100644 --- a/internal/testhelpers/mocks.go +++ b/internal/testhelpers/mocks.go @@ -1,11 +1,113 @@ package testhelpers import ( + "context" + "fmt" "os/user" + "runtime" "github.com/stretchr/testify/mock" ) +func methodIsMocked(m *mock.Mock) bool { + // Get the calling method name + pc := make([]uintptr, 1) + n := runtime.Callers(1, pc) + if n == 0 { + // No caller found - fall back on "not mocked" + return false + } + caller := runtime.FuncForPC(pc[0] - 1) + if caller == nil { + // Caller not found - fall back on "not mocked" + return false + } + + for _, call := range m.ExpectedCalls { + if call.Method == caller.Name() { + return true + } + } + + return false +} + +type notMocked struct{} + +var NotMockedValue notMocked = notMocked{} + +func mockWithDefault[T any](args mock.Arguments, index int, defaultValue T) T { + if len(args) <= index { + return defaultValue + } + + mockedValue := args.Get(index) + if _, ok := mockedValue.(notMocked); ok { + return defaultValue + } + + return mockedValue.(T) +} + +type MockTraceLogger struct { + mock.Mock +} + +func (l *MockTraceLogger) Region(ctx context.Context, category string, label string) (context.Context, func()) { + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, category, label) + } + return mockWithDefault(fnArgs, 0, ctx), mockWithDefault(fnArgs, 1, func() {}) +} + +func (l *MockTraceLogger) LogCommand(ctx context.Context, commandName string) context.Context { + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, commandName) + } + return mockWithDefault(fnArgs, 0, ctx) +} + +func (l *MockTraceLogger) Error(ctx context.Context, err error) error { + // Input validation + if err == nil { + panic("err must be nil") + } + + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, err) + } + return mockWithDefault(fnArgs, 0, err) +} + +func (l *MockTraceLogger) Errorf(ctx context.Context, format string, a ...any) error { + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, format, a) + } + return mockWithDefault(fnArgs, 0, fmt.Errorf(format, a...)) +} + +func (l *MockTraceLogger) Exit(ctx context.Context, exitCode int) { + if methodIsMocked(&l.Mock) { + l.Called(ctx, exitCode) + } +} + +func (l *MockTraceLogger) Fatal(ctx context.Context, err error) { + if methodIsMocked(&l.Mock) { + l.Called(ctx, err) + } +} + +func (l *MockTraceLogger) Fatalf(ctx context.Context, format string, a ...any) { + if methodIsMocked(&l.Mock) { + l.Called(ctx, format, a) + } +} + type MockUserProvider struct { mock.Mock } From 8a22d6e31e7b3dc7cdf7acf64316be477afb511b Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Thu, 23 Feb 2023 14:27:53 -0800 Subject: [PATCH 08/10] log: capture exit conditions Replace executions of 'os.Exit()', 'log.Fatal()', and 'log.Fatalf()' with 'TraceLogger.Exit()', 'TraceLogger.Fatal()', and 'TraceLogger.Fatalf()', respectively. When these are called, log the appropriate 'exit'/'atexit' events. Note that 'TraceLogger.Fatal()' and 'TraceLogger.Fatalf()' currently do not log the error (or error format + args) they receive. This will be updated in a later patch, but the unused arguments are added now to avoid needing to remove then later add back the arguments passed to 'log.Fatal()' and 'log.Fatalf()' by their callers. Signed-off-by: Victoria Dye --- cmd/git-bundle-server/main.go | 9 ++++----- cmd/git-bundle-web-server/bundle-server.go | 9 ++++----- internal/argparse/argparse.go | 5 ++--- internal/log/logger.go | 6 +++++- internal/log/trace2.go | 13 +++++++++++++ 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index 0654a9e..ba6bbf2 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -2,14 +2,13 @@ package main import ( "context" - "log" "os" "github.com/github/git-bundle-server/internal/argparse" - tracelog "github.com/github/git-bundle-server/internal/log" + "github.com/github/git-bundle-server/internal/log" ) -func all(logger tracelog.TraceLogger) []argparse.Subcommand { +func all(logger log.TraceLogger) []argparse.Subcommand { return []argparse.Subcommand{ NewDeleteCommand(logger), NewInitCommand(logger), @@ -22,7 +21,7 @@ func all(logger tracelog.TraceLogger) []argparse.Subcommand { } func main() { - tracelog.WithTraceLogger(context.Background(), func(ctx context.Context, logger tracelog.TraceLogger) { + log.WithTraceLogger(context.Background(), func(ctx context.Context, logger log.TraceLogger) { cmds := all(logger) parser := argparse.NewArgParser(logger, "git-bundle-server []") @@ -34,7 +33,7 @@ func main() { err := parser.InvokeSubcommand(ctx) if err != nil { - log.Fatalf("Failed with error: %s", err) + logger.Fatalf(ctx, "Failed with error: %s", err) } }) } diff --git a/cmd/git-bundle-web-server/bundle-server.go b/cmd/git-bundle-web-server/bundle-server.go index db923ae..78cb50d 100644 --- a/cmd/git-bundle-web-server/bundle-server.go +++ b/cmd/git-bundle-web-server/bundle-server.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "log" "net/http" "os" "os/signal" @@ -13,17 +12,17 @@ import ( "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/core" - tracelog "github.com/github/git-bundle-server/internal/log" + "github.com/github/git-bundle-server/internal/log" ) type bundleWebServer struct { - logger tracelog.TraceLogger + logger log.TraceLogger server *http.Server serverWaitGroup *sync.WaitGroup listenAndServeFunc func() error } -func NewBundleWebServer(logger tracelog.TraceLogger, +func NewBundleWebServer(logger log.TraceLogger, port string, certFile string, keyFile string, ) *bundleWebServer { bundleServer := &bundleWebServer{ @@ -123,7 +122,7 @@ func (b *bundleWebServer) StartServerAsync(ctx context.Context) { // Return error unless it indicates graceful shutdown err := b.listenAndServeFunc() if err != nil && err != http.ErrServerClosed { - log.Fatal(err) + b.logger.Fatal(ctx, err) } }(ctx) diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index aae5fe5..184a964 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -4,7 +4,6 @@ import ( "context" "flag" "fmt" - "os" "strings" "github.com/github/git-bundle-server/internal/log" @@ -146,7 +145,7 @@ func (a *argParser) Parse(ctx context.Context, args []string) { if err != nil { // The error was already printed (via a.FlagSet.Usage()), so we // just need to exit - os.Exit(usageExitCode) + a.logger.Exit(ctx, usageExitCode) } if len(a.subcommands) > 0 { @@ -222,5 +221,5 @@ func (a *argParser) Usage(ctx context.Context, errFmt string, args ...any) { fmt.Fprintf(a.FlagSet.Output(), errFmt+"\n", args...) a.FlagSet.Usage() - os.Exit(usageExitCode) + a.logger.Exit(ctx, usageExitCode) } diff --git a/internal/log/logger.go b/internal/log/logger.go index cf062ea..f2b57a1 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -7,7 +7,11 @@ import ( "runtime/debug" ) -type TraceLogger interface{} +type TraceLogger interface { + Exit(ctx context.Context, exitCode int) + Fatal(ctx context.Context, err error) + Fatalf(ctx context.Context, format string, a ...any) +} type traceLoggerInternal interface { // Internal setup/teardown functions diff --git a/internal/log/trace2.go b/internal/log/trace2.go index ddad24f..e28561c 100644 --- a/internal/log/trace2.go +++ b/internal/log/trace2.go @@ -165,3 +165,16 @@ func (t *Trace2) logExit(ctx context.Context, exitCode int) { t.logger.Sync() } + +func (t *Trace2) Exit(ctx context.Context, exitCode int) { + t.logExit(ctx, exitCode) + os.Exit(exitCode) +} + +func (t *Trace2) Fatal(ctx context.Context, err error) { + t.Exit(ctx, 1) +} + +func (t *Trace2) Fatalf(ctx context.Context, format string, a ...any) { + t.Exit(ctx, 1) +} From c9f4ef5730567b12a172b8e0863f4e6f005b20e4 Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Thu, 23 Feb 2023 13:20:29 -0800 Subject: [PATCH 09/10] log: add error logging Add logging for errors returned 'fmt.Errorf()' errors as well as the errors passed to or created by 'TraceLogger.Fatal()'/'TraceLogger.Fatalf()'. Replace all use of 'fmt.Errorf()' or direct error returns with 'TraceLogger.Errorf()' or 'TraceLogger.Error()', which return 'error' instances to stay consistent with what they're replacing. These errors are logged under the "error" trace2 event. We want to capture the error as low in the call stack as possible (to more precisely indicate the cause), but also not log a new "error" event for every 'fmt.Errorf()' replaced. To do that, have the error wrappers return a custom type 'loggedError', and only log the error event if the 'err' or none of the 'Errorf()' arguments are of the 'loggedError' type. Signed-off-by: Victoria Dye --- cmd/git-bundle-server/delete.go | 8 +++--- cmd/git-bundle-server/init.go | 14 ++++----- cmd/git-bundle-server/start.go | 5 ++-- cmd/git-bundle-server/stop.go | 7 ++++- cmd/git-bundle-server/update-all.go | 11 ++++---- cmd/git-bundle-server/update.go | 10 +++---- cmd/git-bundle-server/web-server.go | 34 +++++++++++----------- internal/argparse/argparse.go | 2 ++ internal/daemon/launchd.go | 44 ++++++++++++++--------------- internal/daemon/systemd.go | 28 +++++++++--------- internal/log/logger.go | 6 ++++ internal/log/trace2.go | 34 ++++++++++++++++++++++ 12 files changed, 124 insertions(+), 79 deletions(-) diff --git a/cmd/git-bundle-server/delete.go b/cmd/git-bundle-server/delete.go index 270b11d..5833341 100644 --- a/cmd/git-bundle-server/delete.go +++ b/cmd/git-bundle-server/delete.go @@ -36,22 +36,22 @@ func (d *deleteCmd) Run(ctx context.Context, args []string) error { repo, err := core.CreateRepository(*route) if err != nil { - return err + return d.logger.Error(ctx, err) } err = core.RemoveRoute(*route) if err != nil { - return err + return d.logger.Error(ctx, err) } err = os.RemoveAll(repo.WebDir) if err != nil { - return err + return d.logger.Error(ctx, err) } err = os.RemoveAll(repo.RepoDir) if err != nil { - return err + return d.logger.Error(ctx, err) } return nil diff --git a/cmd/git-bundle-server/init.go b/cmd/git-bundle-server/init.go index bd3c773..2a1d3bd 100644 --- a/cmd/git-bundle-server/init.go +++ b/cmd/git-bundle-server/init.go @@ -40,24 +40,24 @@ func (i *initCmd) Run(ctx context.Context, args []string) error { repo, err := core.CreateRepository(*route) if err != nil { - return err + return i.logger.Error(ctx, err) } fmt.Printf("Cloning repository from %s\n", *url) gitErr := git.GitCommand("clone", "--bare", *url, repo.RepoDir) if gitErr != nil { - return fmt.Errorf("failed to clone repository: %w", gitErr) + return i.logger.Errorf(ctx, "failed to clone repository: %w", gitErr) } gitErr = git.GitCommand("-C", repo.RepoDir, "config", "remote.origin.fetch", "+refs/heads/*:refs/heads/*") if gitErr != nil { - return fmt.Errorf("failed to configure refspec: %w", gitErr) + return i.logger.Errorf(ctx, "failed to configure refspec: %w", gitErr) } gitErr = git.GitCommand("-C", repo.RepoDir, "fetch", "origin") if gitErr != nil { - return fmt.Errorf("failed to fetch latest refs: %w", gitErr) + return i.logger.Errorf(ctx, "failed to fetch latest refs: %w", gitErr) } bundle := bundles.CreateInitialBundle(repo) @@ -65,16 +65,16 @@ func (i *initCmd) Run(ctx context.Context, args []string) error { written, gitErr := git.CreateBundle(repo.RepoDir, bundle.Filename) if gitErr != nil { - return fmt.Errorf("failed to create bundle: %w", gitErr) + return i.logger.Errorf(ctx, "failed to create bundle: %w", gitErr) } if !written { - return fmt.Errorf("refused to write empty bundle. Is the repo empty?") + return i.logger.Errorf(ctx, "refused to write empty bundle. Is the repo empty?") } list := bundles.CreateSingletonList(bundle) listErr := bundles.WriteBundleList(list, repo) if listErr != nil { - return fmt.Errorf("failed to write bundle list: %w", listErr) + return i.logger.Errorf(ctx, "failed to write bundle list: %w", listErr) } SetCronSchedule() diff --git a/cmd/git-bundle-server/start.go b/cmd/git-bundle-server/start.go index b26f244..8dea0e8 100644 --- a/cmd/git-bundle-server/start.go +++ b/cmd/git-bundle-server/start.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "os" "github.com/github/git-bundle-server/internal/argparse" @@ -38,12 +37,12 @@ func (s *startCmd) Run(ctx context.Context, args []string) error { // CreateRepository registers the route. repo, err := core.CreateRepository(*route) if err != nil { - return err + return s.logger.Error(ctx, err) } _, err = os.ReadDir(repo.RepoDir) if err != nil { - return fmt.Errorf("route '%s' appears to have been deleted; use 'init' instead", *route) + return s.logger.Errorf(ctx, "route '%s' appears to have been deleted; use 'init' instead", *route) } // Make sure we have the global schedule running. diff --git a/cmd/git-bundle-server/stop.go b/cmd/git-bundle-server/stop.go index c051dcc..3fa591b 100644 --- a/cmd/git-bundle-server/stop.go +++ b/cmd/git-bundle-server/stop.go @@ -33,5 +33,10 @@ func (s *stopCmd) Run(ctx context.Context, args []string) error { route := parser.PositionalString("route", "the route for which bundles should stop being generated") parser.Parse(ctx, args) - return core.RemoveRoute(*route) + err := core.RemoveRoute(*route) + if err != nil { + s.logger.Error(ctx, err) + } + + return nil } diff --git a/cmd/git-bundle-server/update-all.go b/cmd/git-bundle-server/update-all.go index 1322a01..432c39b 100644 --- a/cmd/git-bundle-server/update-all.go +++ b/cmd/git-bundle-server/update-all.go @@ -2,7 +2,6 @@ package main import ( "context" - "fmt" "os" "os/exec" @@ -34,7 +33,7 @@ For every configured route, run 'git-bundle-server update '.` func (u *updateAllCmd) Run(ctx context.Context, args []string) error { user, err := common.NewUserProvider().CurrentUser() if err != nil { - return err + return u.logger.Error(ctx, err) } fs := common.NewFileSystem() @@ -43,12 +42,12 @@ func (u *updateAllCmd) Run(ctx context.Context, args []string) error { exe, err := os.Executable() if err != nil { - return fmt.Errorf("failed to get path to execuable: %w", err) + return u.logger.Errorf(ctx, "failed to get path to execuable: %w", err) } repos, err := core.GetRepositories(user, fs) if err != nil { - return err + return u.logger.Error(ctx, err) } subargs := []string{"update", ""} @@ -62,12 +61,12 @@ func (u *updateAllCmd) Run(ctx context.Context, args []string) error { err := cmd.Start() if err != nil { - return fmt.Errorf("git command failed to start: %w", err) + return u.logger.Errorf(ctx, "git command failed to start: %w", err) } err = cmd.Wait() if err != nil { - return fmt.Errorf("git command returned a failure: %w", err) + return u.logger.Errorf(ctx, "git command returned a failure: %w", err) } } diff --git a/cmd/git-bundle-server/update.go b/cmd/git-bundle-server/update.go index acc9709..13b0060 100644 --- a/cmd/git-bundle-server/update.go +++ b/cmd/git-bundle-server/update.go @@ -38,18 +38,18 @@ func (u *updateCmd) Run(ctx context.Context, args []string) error { repo, err := core.CreateRepository(*route) if err != nil { - return err + return u.logger.Error(ctx, err) } list, err := bundles.GetBundleList(repo) if err != nil { - return fmt.Errorf("failed to load bundle list: %w", err) + return u.logger.Errorf(ctx, "failed to load bundle list: %w", err) } fmt.Printf("Creating new incremental bundle\n") bundle, err := bundles.CreateIncrementalBundle(repo, list) if err != nil { - return err + return u.logger.Error(ctx, err) } // Nothing new! @@ -62,13 +62,13 @@ func (u *updateCmd) Run(ctx context.Context, args []string) error { fmt.Printf("Collapsing bundle list\n") err = bundles.CollapseList(repo, list) if err != nil { - return err + return u.logger.Error(ctx, err) } fmt.Printf("Writing updated bundle list\n") listErr := bundles.WriteBundleList(list, repo) if listErr != nil { - return fmt.Errorf("failed to write bundle list: %w", listErr) + return u.logger.Errorf(ctx, "failed to write bundle list: %w", listErr) } return nil diff --git a/cmd/git-bundle-server/web-server.go b/cmd/git-bundle-server/web-server.go index 677782c..912089a 100644 --- a/cmd/git-bundle-server/web-server.go +++ b/cmd/git-bundle-server/web-server.go @@ -41,7 +41,7 @@ func (webServerCmd) Description() string { return `Manage the web server hosting bundle content` } -func (w *webServerCmd) getDaemonConfig() (*daemon.DaemonConfig, error) { +func (w *webServerCmd) getDaemonConfig(ctx context.Context) (*daemon.DaemonConfig, error) { // Find git-bundle-web-server // First, search for it on the path programPath, err := exec.LookPath("git-bundle-web-server") @@ -50,25 +50,25 @@ func (w *webServerCmd) getDaemonConfig() (*daemon.DaemonConfig, error) { // Result is a relative path programPath, err = filepath.Abs(programPath) if err != nil { - return nil, fmt.Errorf("could not get absolute path to program: %w", err) + return nil, w.logger.Errorf(ctx, "could not get absolute path to program: %w", err) } } else { // Fall back on looking for it in the same directory as the currently-running executable exePath, err := os.Executable() if err != nil { - return nil, fmt.Errorf("failed to get path to current executable: %w", err) + return nil, w.logger.Errorf(ctx, "failed to get path to current executable: %w", err) } exeDir := filepath.Dir(exePath) if err != nil { - return nil, fmt.Errorf("failed to get parent dir of current executable: %w", err) + return nil, w.logger.Errorf(ctx, "failed to get parent dir of current executable: %w", err) } programPath = filepath.Join(exeDir, "git-bundle-web-server") programExists, err := w.fileSystem.FileExists(programPath) if err != nil { - return nil, fmt.Errorf("could not determine whether path to 'git-bundle-web-server' exists: %w", err) + return nil, w.logger.Errorf(ctx, "could not determine whether path to 'git-bundle-web-server' exists: %w", err) } else if !programExists { - return nil, fmt.Errorf("could not find path to 'git-bundle-web-server'") + return nil, w.logger.Errorf(ctx, "could not find path to 'git-bundle-web-server'") } } } @@ -99,12 +99,12 @@ func (w *webServerCmd) startServer(ctx context.Context, args []string) error { d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) if err != nil { - return err + return w.logger.Error(ctx, err) } - config, err := w.getDaemonConfig() + config, err := w.getDaemonConfig(ctx) if err != nil { - return err + return w.logger.Error(ctx, err) } // Configure flags @@ -131,17 +131,17 @@ func (w *webServerCmd) startServer(ctx context.Context, args []string) error { }) if loopErr != nil { // Error happened in 'Visit' - return loopErr + return w.logger.Error(ctx, loopErr) } err = d.Create(ctx, config, *force) if err != nil { - return err + return w.logger.Error(ctx, err) } err = d.Start(ctx, config.Label) if err != nil { - return err + return w.logger.Error(ctx, err) } return nil @@ -155,23 +155,23 @@ func (w *webServerCmd) stopServer(ctx context.Context, args []string) error { d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) if err != nil { - return err + return w.logger.Error(ctx, err) } - config, err := w.getDaemonConfig() + config, err := w.getDaemonConfig(ctx) if err != nil { - return err + return w.logger.Error(ctx, err) } err = d.Stop(ctx, config.Label) if err != nil { - return err + return w.logger.Error(ctx, err) } if *remove { err = d.Remove(ctx, config.Label) if err != nil { - return err + return w.logger.Error(ctx, err) } } diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index 184a964..006e531 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -145,6 +145,7 @@ func (a *argParser) Parse(ctx context.Context, args []string) { if err != nil { // The error was already printed (via a.FlagSet.Usage()), so we // just need to exit + a.logger.Error(ctx, err) a.logger.Exit(ctx, usageExitCode) } @@ -221,5 +222,6 @@ func (a *argParser) Usage(ctx context.Context, errFmt string, args ...any) { fmt.Fprintf(a.FlagSet.Output(), errFmt+"\n", args...) a.FlagSet.Usage() + a.logger.Errorf(ctx, errFmt, args...) a.logger.Exit(ctx, usageExitCode) } diff --git a/internal/daemon/launchd.go b/internal/daemon/launchd.go index df87d6f..30e307a 100644 --- a/internal/daemon/launchd.go +++ b/internal/daemon/launchd.go @@ -116,7 +116,7 @@ func (l *launchd) isBootstrapped(ctx context.Context, serviceTarget string) (boo // run 'launchctl print' on given service target to see if it exists exitCode, err := l.cmdExec.Run("launchctl", "print", serviceTarget) if err != nil { - return false, err + return false, l.logger.Error(ctx, err) } if exitCode == 0 { @@ -124,7 +124,7 @@ func (l *launchd) isBootstrapped(ctx context.Context, serviceTarget string) (boo } else if exitCode == LaunchdServiceNotFoundErrorCode { return false, nil } else { - return false, fmt.Errorf("could not determine if service '%s' is bootstrapped: "+ + return false, l.logger.Errorf(ctx, "could not determine if service '%s' is bootstrapped: "+ "'launchctl print' exited with status '%d'", serviceTarget, exitCode) } } @@ -133,11 +133,11 @@ func (l *launchd) bootstrapFile(ctx context.Context, domain string, filename str // run 'launchctl bootstrap' on given domain & file exitCode, err := l.cmdExec.Run("launchctl", "bootstrap", domain, filename) if err != nil { - return err + return l.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'launchctl bootstrap' exited with status %d", exitCode) + return l.logger.Errorf(ctx, "'launchctl bootstrap' exited with status %d", exitCode) } return nil @@ -147,7 +147,7 @@ func (l *launchd) bootout(ctx context.Context, serviceTarget string) (bool, erro // run 'launchctl bootout' on given service target exitCode, err := l.cmdExec.Run("launchctl", "bootout", serviceTarget) if err != nil { - return false, err + return false, l.logger.Error(ctx, err) } if exitCode == 0 { @@ -155,7 +155,7 @@ func (l *launchd) bootout(ctx context.Context, serviceTarget string) (bool, erro } else if exitCode == LaunchdNoSuchProcessErrorCode { return false, nil } else { - return false, fmt.Errorf("'launchctl bootout' failed with status %d", exitCode) + return false, l.logger.Errorf(ctx, "'launchctl bootout' failed with status %d", exitCode) } } @@ -176,13 +176,13 @@ func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) encoder.Indent("", " ") err := encoder.Encode(lConfig.toPlist()) if err != nil { - return fmt.Errorf("could not encode plist: %w", err) + return l.logger.Errorf(ctx, "could not encode plist: %w", err) } // Check the existing file - if it's the same as the new content, do not overwrite user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } filename := filepath.Join(user.HomeDir, "Library", "LaunchAgents", fmt.Sprintf("%s.plist", config.Label)) @@ -191,12 +191,12 @@ func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) alreadyLoaded, err := l.isBootstrapped(ctx, serviceTarget) if err != nil { - return err + return l.logger.Error(ctx, err) } fileExists, err := l.fileSystem.FileExists(filename) if err != nil { - return fmt.Errorf("could not determine whether plist '%s' exists: %w", filename, err) + return l.logger.Errorf(ctx, "could not determine whether plist '%s' exists: %w", filename, err) } // If not forcing re-configuration & the service configuration is valid, @@ -209,7 +209,7 @@ func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) if alreadyLoaded { _, err = l.bootout(ctx, serviceTarget) if err != nil { - return fmt.Errorf("could not bootout daemon process '%s': %w", config.Label, err) + return l.logger.Errorf(ctx, "could not bootout daemon process '%s': %w", config.Label, err) } } @@ -218,13 +218,13 @@ func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) // TODO: only overwrite file if file contents have changed err = l.fileSystem.WriteFile(filename, newPlist.Bytes()) if err != nil { - return fmt.Errorf("unable to write plist file: %w", err) + return l.logger.Errorf(ctx, "unable to write plist file: %w", err) } } err = l.bootstrapFile(ctx, domainTarget, filename) if err != nil { - return fmt.Errorf("could not bootstrap daemon process '%s': %w", config.Label, err) + return l.logger.Errorf(ctx, "could not bootstrap daemon process '%s': %w", config.Label, err) } return nil @@ -233,18 +233,18 @@ func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) func (l *launchd) Start(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, label) exitCode, err := l.cmdExec.Run("launchctl", "kickstart", serviceTarget) if err != nil { - return err + return l.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'launchctl kickstart' exited with status %d", exitCode) + return l.logger.Errorf(ctx, "'launchctl kickstart' exited with status %d", exitCode) } return nil @@ -253,21 +253,21 @@ func (l *launchd) Start(ctx context.Context, label string) error { func (l *launchd) Stop(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, label) exitCode, err := l.cmdExec.Run("launchctl", "kill", "SIGINT", serviceTarget) if err != nil { - return err + return l.logger.Error(ctx, err) } // Don't throw an error if the service hasn't been bootstrapped if exitCode != 0 && exitCode != LaunchdServiceNotFoundErrorCode && exitCode != LaunchdNoSuchProcessErrorCode { - return fmt.Errorf("'launchctl kill' exited with status %d", exitCode) + return l.logger.Errorf(ctx, "'launchctl kill' exited with status %d", exitCode) } return nil @@ -276,7 +276,7 @@ func (l *launchd) Stop(ctx context.Context, label string) error { func (l *launchd) Remove(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } filename := filepath.Join(user.HomeDir, "Library", "LaunchAgents", fmt.Sprintf("%s.plist", label)) @@ -285,12 +285,12 @@ func (l *launchd) Remove(ctx context.Context, label string) error { _, err = l.bootout(ctx, serviceTarget) if err != nil { - return fmt.Errorf("could not remove daemon process '%s': %w", label, err) + return l.logger.Errorf(ctx, "could not remove daemon process '%s': %w", label, err) } _, err = l.fileSystem.DeleteFile(filename) if err != nil { - return fmt.Errorf("could not delete launchd plist: %w", err) + return l.logger.Errorf(ctx, "could not delete launchd plist: %w", err) } return nil diff --git a/internal/daemon/systemd.go b/internal/daemon/systemd.go index e81586a..6e97465 100644 --- a/internal/daemon/systemd.go +++ b/internal/daemon/systemd.go @@ -46,11 +46,11 @@ func NewSystemdProvider( func (s *systemd) reloadDaemon(ctx context.Context) error { exitCode, err := s.cmdExec.Run("systemctl", "--user", "daemon-reload") if err != nil { - return err + return s.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'systemctl --user daemon-reload' exited with status %d", exitCode) + return s.logger.Errorf(ctx, "'systemctl --user daemon-reload' exited with status %d", exitCode) } return nil @@ -59,7 +59,7 @@ func (s *systemd) reloadDaemon(ctx context.Context) error { func (s *systemd) Create(ctx context.Context, config *DaemonConfig, force bool) error { user, err := s.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for systemd service: %w", err) + return s.logger.Errorf(ctx, "could not get current user for systemd service: %w", err) } // Generate the configuration @@ -70,7 +70,7 @@ func (s *systemd) Create(ctx context.Context, config *DaemonConfig, force bool) }, }).Parse(serviceTemplate) if err != nil { - return fmt.Errorf("unable to generate systemd configuration: %w", err) + return s.logger.Errorf(ctx, "unable to generate systemd configuration: %w", err) } t.Execute(&newServiceUnit, config) @@ -79,7 +79,7 @@ func (s *systemd) Create(ctx context.Context, config *DaemonConfig, force bool) // Check whether the file exists fileExists, err := s.fileSystem.FileExists(filename) if err != nil { - return fmt.Errorf("could not determine whether service unit '%s' exists: %w", config.Label, err) + return s.logger.Errorf(ctx, "could not determine whether service unit '%s' exists: %w", config.Label, err) } if !force && fileExists { @@ -90,13 +90,13 @@ func (s *systemd) Create(ctx context.Context, config *DaemonConfig, force bool) // Otherwise, write the new file err = s.fileSystem.WriteFile(filename, newServiceUnit.Bytes()) if err != nil { - return fmt.Errorf("unable to write service unit: %w", err) + return s.logger.Errorf(ctx, "unable to write service unit: %w", err) } // Reload the user-scoped service units after adding err = s.reloadDaemon(ctx) if err != nil { - return err + return s.logger.Error(ctx, err) } return nil @@ -106,11 +106,11 @@ func (s *systemd) Start(ctx context.Context, label string) error { // TODO: warn user if already running exitCode, err := s.cmdExec.Run("systemctl", "--user", "start", label) if err != nil { - return err + return s.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'systemctl stop' exited with status %d", exitCode) + return s.logger.Errorf(ctx, "'systemctl stop' exited with status %d", exitCode) } return nil @@ -120,11 +120,11 @@ func (s *systemd) Stop(ctx context.Context, label string) error { // TODO: warn user if already stopped exitCode, err := s.cmdExec.Run("systemctl", "--user", "stop", label) if err != nil { - return err + return s.logger.Error(ctx, err) } if exitCode != 0 && exitCode != SystemdUnitNotInstalledErrorCode { - return fmt.Errorf("'systemctl stop' exited with status %d", exitCode) + return s.logger.Errorf(ctx, "'systemctl stop' exited with status %d", exitCode) } return nil @@ -133,19 +133,19 @@ func (s *systemd) Stop(ctx context.Context, label string) error { func (s *systemd) Remove(ctx context.Context, label string) error { user, err := s.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return s.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } filename := filepath.Join(user.HomeDir, ".config", "systemd", "user", fmt.Sprintf("%s.service", label)) _, err = s.fileSystem.DeleteFile(filename) if err != nil { - return fmt.Errorf("could not delete service unit: %w", err) + return s.logger.Errorf(ctx, "could not delete service unit: %w", err) } // Reload the user-scoped service units after removing err = s.reloadDaemon(ctx) if err != nil { - return err + return s.logger.Error(ctx, err) } return nil diff --git a/internal/log/logger.go b/internal/log/logger.go index f2b57a1..368ad64 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -7,7 +7,13 @@ import ( "runtime/debug" ) +// Type alias used to keep track of whether an error was already logged deeper +// in the call stack. +type loggedError error + type TraceLogger interface { + Error(ctx context.Context, err error) error + Errorf(ctx context.Context, format string, a ...any) error Exit(ctx context.Context, exitCode int) Fatal(ctx context.Context, err error) Fatalf(ctx context.Context, format string, a ...any) diff --git a/internal/log/trace2.go b/internal/log/trace2.go index e28561c..4e1eb90 100644 --- a/internal/log/trace2.go +++ b/internal/log/trace2.go @@ -166,6 +166,40 @@ func (t *Trace2) logExit(ctx context.Context, exitCode int) { t.logger.Sync() } +func (t *Trace2) Error(ctx context.Context, err error) error { + // We only want to log the error if it's not already logged deeper in the + // call stack. + if _, ok := err.(loggedError); !ok { + _, sharedFields := t.sharedFields(ctx) + t.logger.Error("error", sharedFields.with( + zap.String("msg", err.Error()), + zap.String("fmt", err.Error()))...) + } + return loggedError(err) +} + +func (t *Trace2) Errorf(ctx context.Context, format string, a ...any) error { + // We only want to log the error if it's not already logged deeper in the + // call stack. + isLogged := false + for _, fmtArg := range a { + if _, ok := fmtArg.(loggedError); ok { + isLogged = true + break + } + } + + err := loggedError(fmt.Errorf(format, a...)) + + if isLogged { + _, sharedFields := t.sharedFields(ctx) + t.logger.Info("error", sharedFields.with( + zap.String("msg", err.Error()), + zap.String("fmt", format))...) + } + return err +} + func (t *Trace2) Exit(ctx context.Context, exitCode int) { t.logExit(ctx, exitCode) os.Exit(exitCode) From 3504fd28844282838107d7236a235a27fc20ce4c Mon Sep 17 00:00:00 2001 From: Victoria Dye Date: Wed, 22 Feb 2023 14:56:42 -0800 Subject: [PATCH 10/10] log: add command logging Add a 'LogCommand()' function to the 'TraceLogger', called in 'argparser.InvokeSubcommand()' right before the command is invoked. To remain consistent with Git, do this only when a toplevel command (indicated by 'isTopLevel') is invoked; nothing is done for subcommands of those commands. In the trace2 logger, this corresponds to the "cmd_name" event. Signed-off-by: Victoria Dye --- internal/argparse/argparse.go | 4 ++++ internal/log/logger.go | 1 + internal/log/trace2.go | 8 ++++++++ 3 files changed, 13 insertions(+) diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index 006e531..9fe8f2f 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -215,6 +215,10 @@ func (a *argParser) InvokeSubcommand(ctx context.Context) error { panic("subcommand has not been parsed") } + if a.isTopLevel { + a.logger.LogCommand(ctx, a.selectedSubcommand.Name()) + } + return a.selectedSubcommand.Run(ctx, a.Args()) } diff --git a/internal/log/logger.go b/internal/log/logger.go index 368ad64..41ff3c2 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -12,6 +12,7 @@ import ( type loggedError error type TraceLogger interface { + LogCommand(ctx context.Context, commandName string) context.Context Error(ctx context.Context, err error) error Errorf(ctx context.Context, format string, a ...any) error Exit(ctx context.Context, exitCode int) diff --git a/internal/log/trace2.go b/internal/log/trace2.go index 4e1eb90..789dd55 100644 --- a/internal/log/trace2.go +++ b/internal/log/trace2.go @@ -166,6 +166,14 @@ func (t *Trace2) logExit(ctx context.Context, exitCode int) { t.logger.Sync() } +func (t *Trace2) LogCommand(ctx context.Context, commandName string) context.Context { + ctx, sharedFields := t.sharedFields(ctx) + + t.logger.Info("cmd_name", sharedFields.with(zap.String("name", commandName))...) + + return ctx +} + func (t *Trace2) Error(ctx context.Context, err error) error { // We only want to log the error if it's not already logged deeper in the // call stack.