diff --git a/cmd/git-bundle-server/delete.go b/cmd/git-bundle-server/delete.go index 6233667..5833341 100644 --- a/cmd/git-bundle-server/delete.go +++ b/cmd/git-bundle-server/delete.go @@ -1,47 +1,57 @@ package main import ( + "context" "os" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type Delete struct{} +type deleteCmd struct { + logger log.TraceLogger +} + +func NewDeleteCommand(logger log.TraceLogger) argparse.Subcommand { + return &deleteCmd{ + logger: logger, + } +} -func (Delete) Name() string { +func (deleteCmd) Name() string { return "delete" } -func (Delete) Description() string { +func (deleteCmd) Description() string { return ` Remove the configuration for the given '' and delete its repository data.` } -func (Delete) Run(args []string) error { - parser := argparse.NewArgParser("git-bundle-server delete ") +func (d *deleteCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(d.logger, "git-bundle-server delete ") route := parser.PositionalString("route", "the route to delete") - parser.Parse(args) + parser.Parse(ctx, args) repo, err := core.CreateRepository(*route) if err != nil { - return err + return d.logger.Error(ctx, err) } err = core.RemoveRoute(*route) if err != nil { - return err + return d.logger.Error(ctx, err) } err = os.RemoveAll(repo.WebDir) if err != nil { - return err + return d.logger.Error(ctx, err) } err = os.RemoveAll(repo.RepoDir) if err != nil { - return err + return d.logger.Error(ctx, err) } return nil diff --git a/cmd/git-bundle-server/init.go b/cmd/git-bundle-server/init.go index 0daf65e..2a1d3bd 100644 --- a/cmd/git-bundle-server/init.go +++ b/cmd/git-bundle-server/init.go @@ -1,53 +1,63 @@ package main import ( + "context" "fmt" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/bundles" "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/git" + "github.com/github/git-bundle-server/internal/log" ) -type Init struct{} +type initCmd struct { + logger log.TraceLogger +} + +func NewInitCommand(logger log.TraceLogger) argparse.Subcommand { + return &initCmd{ + logger: logger, + } +} -func (Init) Name() string { +func (initCmd) Name() string { return "init" } -func (Init) Description() string { +func (initCmd) Description() string { return ` Initialize a repository by cloning a bare repo from '', whose bundles should be hosted at ''.` } -func (Init) Run(args []string) error { - parser := argparse.NewArgParser("git-bundle-server init ") +func (i *initCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(i.logger, "git-bundle-server init ") url := parser.PositionalString("url", "the URL of a repository to clone") // TODO: allow parsing out of route := parser.PositionalString("route", "the route to host the specified repo") - parser.Parse(args) + parser.Parse(ctx, args) repo, err := core.CreateRepository(*route) if err != nil { - return err + return i.logger.Error(ctx, err) } fmt.Printf("Cloning repository from %s\n", *url) gitErr := git.GitCommand("clone", "--bare", *url, repo.RepoDir) if gitErr != nil { - return fmt.Errorf("failed to clone repository: %w", gitErr) + return i.logger.Errorf(ctx, "failed to clone repository: %w", gitErr) } gitErr = git.GitCommand("-C", repo.RepoDir, "config", "remote.origin.fetch", "+refs/heads/*:refs/heads/*") if gitErr != nil { - return fmt.Errorf("failed to configure refspec: %w", gitErr) + return i.logger.Errorf(ctx, "failed to configure refspec: %w", gitErr) } gitErr = git.GitCommand("-C", repo.RepoDir, "fetch", "origin") if gitErr != nil { - return fmt.Errorf("failed to fetch latest refs: %w", gitErr) + return i.logger.Errorf(ctx, "failed to fetch latest refs: %w", gitErr) } bundle := bundles.CreateInitialBundle(repo) @@ -55,16 +65,16 @@ func (Init) Run(args []string) error { written, gitErr := git.CreateBundle(repo.RepoDir, bundle.Filename) if gitErr != nil { - return fmt.Errorf("failed to create bundle: %w", gitErr) + return i.logger.Errorf(ctx, "failed to create bundle: %w", gitErr) } if !written { - return fmt.Errorf("refused to write empty bundle. Is the repo empty?") + return i.logger.Errorf(ctx, "refused to write empty bundle. Is the repo empty?") } list := bundles.CreateSingletonList(bundle) listErr := bundles.WriteBundleList(list, repo) if listErr != nil { - return fmt.Errorf("failed to write bundle list: %w", listErr) + return i.logger.Errorf(ctx, "failed to write bundle list: %w", listErr) } SetCronSchedule() diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index 5ed5b51..ba6bbf2 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -1,36 +1,39 @@ package main import ( - "log" + "context" "os" "github.com/github/git-bundle-server/internal/argparse" + "github.com/github/git-bundle-server/internal/log" ) -func all() []argparse.Subcommand { +func all(logger log.TraceLogger) []argparse.Subcommand { return []argparse.Subcommand{ - Delete{}, - Init{}, - Start{}, - Stop{}, - Update{}, - UpdateAll{}, - NewWebServerCommand(), + NewDeleteCommand(logger), + NewInitCommand(logger), + NewStartCommand(logger), + NewStopCommand(logger), + NewUpdateCommand(logger), + NewUpdateAllCommand(logger), + NewWebServerCommand(logger), } } func main() { - cmds := all() + log.WithTraceLogger(context.Background(), func(ctx context.Context, logger log.TraceLogger) { + cmds := all(logger) - parser := argparse.NewArgParser("git-bundle-server []") - parser.SetIsTopLevel(true) - for _, cmd := range cmds { - parser.Subcommand(cmd) - } - parser.Parse(os.Args[1:]) + parser := argparse.NewArgParser(logger, "git-bundle-server []") + parser.SetIsTopLevel(true) + for _, cmd := range cmds { + parser.Subcommand(cmd) + } + parser.Parse(ctx, os.Args[1:]) - err := parser.InvokeSubcommand() - if err != nil { - log.Fatal("Failed with error: ", err) - } + err := parser.InvokeSubcommand(ctx) + if err != nil { + logger.Fatalf(ctx, "Failed with error: %s", err) + } + }) } diff --git a/cmd/git-bundle-server/start.go b/cmd/git-bundle-server/start.go index 3c6270e..8dea0e8 100644 --- a/cmd/git-bundle-server/start.go +++ b/cmd/git-bundle-server/start.go @@ -1,39 +1,48 @@ package main import ( - "fmt" + "context" "os" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type Start struct{} +type startCmd struct { + logger log.TraceLogger +} + +func NewStartCommand(logger log.TraceLogger) argparse.Subcommand { + return &startCmd{ + logger: logger, + } +} -func (Start) Name() string { +func (startCmd) Name() string { return "start" } -func (Start) Description() string { +func (startCmd) Description() string { return ` Start computing bundles and serving content for the repository at the specified ''.` } -func (Start) Run(args []string) error { - parser := argparse.NewArgParser("git-bundle-server start ") +func (s *startCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(s.logger, "git-bundle-server start ") route := parser.PositionalString("route", "the route for which bundles should be generated") - parser.Parse(args) + parser.Parse(ctx, args) // CreateRepository registers the route. repo, err := core.CreateRepository(*route) if err != nil { - return err + return s.logger.Error(ctx, err) } _, err = os.ReadDir(repo.RepoDir) if err != nil { - return fmt.Errorf("route '%s' appears to have been deleted; use 'init' instead", *route) + return s.logger.Errorf(ctx, "route '%s' appears to have been deleted; use 'init' instead", *route) } // Make sure we have the global schedule running. diff --git a/cmd/git-bundle-server/stop.go b/cmd/git-bundle-server/stop.go index dec74ac..3fa591b 100644 --- a/cmd/git-bundle-server/stop.go +++ b/cmd/git-bundle-server/stop.go @@ -1,26 +1,42 @@ package main import ( + "context" + "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type Stop struct{} +type stopCmd struct { + logger log.TraceLogger +} -func (Stop) Name() string { +func NewStopCommand(logger log.TraceLogger) argparse.Subcommand { + return &stopCmd{ + logger: logger, + } +} + +func (stopCmd) Name() string { return "stop" } -func (Stop) Description() string { +func (stopCmd) Description() string { return ` Stop computing bundles or serving content for the repository at the specified ''.` } -func (Stop) Run(args []string) error { - parser := argparse.NewArgParser("git-bundle-server stop ") +func (s *stopCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(s.logger, "git-bundle-server stop ") route := parser.PositionalString("route", "the route for which bundles should stop being generated") - parser.Parse(args) + parser.Parse(ctx, args) + + err := core.RemoveRoute(*route) + if err != nil { + s.logger.Error(ctx, err) + } - return core.RemoveRoute(*route) + return nil } diff --git a/cmd/git-bundle-server/update-all.go b/cmd/git-bundle-server/update-all.go index 671dac4..432c39b 100644 --- a/cmd/git-bundle-server/update-all.go +++ b/cmd/git-bundle-server/update-all.go @@ -1,44 +1,53 @@ package main import ( - "fmt" + "context" "os" "os/exec" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type UpdateAll struct{} +type updateAllCmd struct { + logger log.TraceLogger +} + +func NewUpdateAllCommand(logger log.TraceLogger) argparse.Subcommand { + return &updateAllCmd{ + logger: logger, + } +} -func (UpdateAll) Name() string { +func (updateAllCmd) Name() string { return "update-all" } -func (UpdateAll) Description() string { +func (updateAllCmd) Description() string { return ` For every configured route, run 'git-bundle-server update '.` } -func (UpdateAll) Run(args []string) error { +func (u *updateAllCmd) Run(ctx context.Context, args []string) error { user, err := common.NewUserProvider().CurrentUser() if err != nil { - return err + return u.logger.Error(ctx, err) } fs := common.NewFileSystem() - parser := argparse.NewArgParser("git-bundle-server update-all") - parser.Parse(args) + parser := argparse.NewArgParser(u.logger, "git-bundle-server update-all") + parser.Parse(ctx, args) exe, err := os.Executable() if err != nil { - return fmt.Errorf("failed to get path to execuable: %w", err) + return u.logger.Errorf(ctx, "failed to get path to execuable: %w", err) } repos, err := core.GetRepositories(user, fs) if err != nil { - return err + return u.logger.Error(ctx, err) } subargs := []string{"update", ""} @@ -52,12 +61,12 @@ func (UpdateAll) Run(args []string) error { err := cmd.Start() if err != nil { - return fmt.Errorf("git command failed to start: %w", err) + return u.logger.Errorf(ctx, "git command failed to start: %w", err) } err = cmd.Wait() if err != nil { - return fmt.Errorf("git command returned a failure: %w", err) + return u.logger.Errorf(ctx, "git command returned a failure: %w", err) } } diff --git a/cmd/git-bundle-server/update.go b/cmd/git-bundle-server/update.go index 514f610..13b0060 100644 --- a/cmd/git-bundle-server/update.go +++ b/cmd/git-bundle-server/update.go @@ -1,45 +1,55 @@ package main import ( + "context" "fmt" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/bundles" "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -type Update struct{} +type updateCmd struct { + logger log.TraceLogger +} + +func NewUpdateCommand(logger log.TraceLogger) argparse.Subcommand { + return &updateCmd{ + logger: logger, + } +} -func (Update) Name() string { +func (updateCmd) Name() string { return "update" } -func (Update) Description() string { +func (updateCmd) Description() string { return ` For the repository in the current directory (or the one specified by ''), fetch the latest content from the remote, create a new set of bundles, and update the bundle list.` } -func (Update) Run(args []string) error { - parser := argparse.NewArgParser("git-bundle-server update ") +func (u *updateCmd) Run(ctx context.Context, args []string) error { + parser := argparse.NewArgParser(u.logger, "git-bundle-server update ") route := parser.PositionalString("route", "the route to update") - parser.Parse(args) + parser.Parse(ctx, args) repo, err := core.CreateRepository(*route) if err != nil { - return err + return u.logger.Error(ctx, err) } list, err := bundles.GetBundleList(repo) if err != nil { - return fmt.Errorf("failed to load bundle list: %w", err) + return u.logger.Errorf(ctx, "failed to load bundle list: %w", err) } fmt.Printf("Creating new incremental bundle\n") bundle, err := bundles.CreateIncrementalBundle(repo, list) if err != nil { - return err + return u.logger.Error(ctx, err) } // Nothing new! @@ -52,13 +62,13 @@ func (Update) Run(args []string) error { fmt.Printf("Collapsing bundle list\n") err = bundles.CollapseList(repo, list) if err != nil { - return err + return u.logger.Error(ctx, err) } fmt.Printf("Writing updated bundle list\n") listErr := bundles.WriteBundleList(list, repo) if listErr != nil { - return fmt.Errorf("failed to write bundle list: %w", listErr) + return u.logger.Errorf(ctx, "failed to write bundle list: %w", listErr) } return nil diff --git a/cmd/git-bundle-server/web-server.go b/cmd/git-bundle-server/web-server.go index c79aa6f..912089a 100644 --- a/cmd/git-bundle-server/web-server.go +++ b/cmd/git-bundle-server/web-server.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "flag" "fmt" @@ -12,32 +13,35 @@ import ( "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/daemon" + "github.com/github/git-bundle-server/internal/log" ) -type webServer struct { +type webServerCmd struct { + logger log.TraceLogger user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } -func NewWebServerCommand() *webServer { - // Create dependencies - return &webServer{ +func NewWebServerCommand(logger log.TraceLogger) argparse.Subcommand { + // Create subcommand-specific dependencies + return &webServerCmd{ + logger: logger, user: common.NewUserProvider(), cmdExec: common.NewCommandExecutor(), fileSystem: common.NewFileSystem(), } } -func (webServer) Name() string { +func (webServerCmd) Name() string { return "web-server" } -func (webServer) Description() string { +func (webServerCmd) Description() string { return `Manage the web server hosting bundle content` } -func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) { +func (w *webServerCmd) getDaemonConfig(ctx context.Context) (*daemon.DaemonConfig, error) { // Find git-bundle-web-server // First, search for it on the path programPath, err := exec.LookPath("git-bundle-web-server") @@ -46,25 +50,25 @@ func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) { // Result is a relative path programPath, err = filepath.Abs(programPath) if err != nil { - return nil, fmt.Errorf("could not get absolute path to program: %w", err) + return nil, w.logger.Errorf(ctx, "could not get absolute path to program: %w", err) } } else { // Fall back on looking for it in the same directory as the currently-running executable exePath, err := os.Executable() if err != nil { - return nil, fmt.Errorf("failed to get path to current executable: %w", err) + return nil, w.logger.Errorf(ctx, "failed to get path to current executable: %w", err) } exeDir := filepath.Dir(exePath) if err != nil { - return nil, fmt.Errorf("failed to get parent dir of current executable: %w", err) + return nil, w.logger.Errorf(ctx, "failed to get parent dir of current executable: %w", err) } programPath = filepath.Join(exeDir, "git-bundle-web-server") programExists, err := w.fileSystem.FileExists(programPath) if err != nil { - return nil, fmt.Errorf("could not determine whether path to 'git-bundle-web-server' exists: %w", err) + return nil, w.logger.Errorf(ctx, "could not determine whether path to 'git-bundle-web-server' exists: %w", err) } else if !programExists { - return nil, fmt.Errorf("could not find path to 'git-bundle-web-server'") + return nil, w.logger.Errorf(ctx, "could not find path to 'git-bundle-web-server'") } } } @@ -76,9 +80,9 @@ func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) { }, nil } -func (w *webServer) startServer(args []string) error { +func (w *webServerCmd) startServer(ctx context.Context, args []string) error { // Parse subcommand arguments - parser := argparse.NewArgParser("git-bundle-server web-server start [-f|--force]") + parser := argparse.NewArgParser(w.logger, "git-bundle-server web-server start [-f|--force]") // Args for 'git-bundle-server web-server start' force := parser.Bool("force", false, "Force reconfiguration of the web server daemon") @@ -90,17 +94,17 @@ func (w *webServer) startServer(args []string) error { parser.Var(f.Value, f.Name, fmt.Sprintf("[Web server] %s", f.Usage)) }) - parser.Parse(args) - validate() + parser.Parse(ctx, args) + validate(ctx) - d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem) + d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) if err != nil { - return err + return w.logger.Error(ctx, err) } - config, err := w.getDaemonConfig() + config, err := w.getDaemonConfig(ctx) if err != nil { - return err + return w.logger.Error(ctx, err) } // Configure flags @@ -127,59 +131,59 @@ func (w *webServer) startServer(args []string) error { }) if loopErr != nil { // Error happened in 'Visit' - return loopErr + return w.logger.Error(ctx, loopErr) } - err = d.Create(config, *force) + err = d.Create(ctx, config, *force) if err != nil { - return err + return w.logger.Error(ctx, err) } - err = d.Start(config.Label) + err = d.Start(ctx, config.Label) if err != nil { - return err + return w.logger.Error(ctx, err) } return nil } -func (w *webServer) stopServer(args []string) error { +func (w *webServerCmd) stopServer(ctx context.Context, args []string) error { // Parse subcommand arguments - parser := argparse.NewArgParser("git-bundle-server web-server stop [--remove]") + parser := argparse.NewArgParser(w.logger, "git-bundle-server web-server stop [--remove]") remove := parser.Bool("remove", false, "Remove the web server daemon configuration from the system after stopping") - parser.Parse(args) + parser.Parse(ctx, args) - d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem) + d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) if err != nil { - return err + return w.logger.Error(ctx, err) } - config, err := w.getDaemonConfig() + config, err := w.getDaemonConfig(ctx) if err != nil { - return err + return w.logger.Error(ctx, err) } - err = d.Stop(config.Label) + err = d.Stop(ctx, config.Label) if err != nil { - return err + return w.logger.Error(ctx, err) } if *remove { - err = d.Remove(config.Label) + err = d.Remove(ctx, config.Label) if err != nil { - return err + return w.logger.Error(ctx, err) } } return nil } -func (w *webServer) Run(args []string) error { +func (w *webServerCmd) Run(ctx context.Context, args []string) error { // Parse command arguments - parser := argparse.NewArgParser("git-bundle-server web-server (start|stop) ") + parser := argparse.NewArgParser(w.logger, "git-bundle-server web-server (start|stop) ") parser.Subcommand(argparse.NewSubcommand("start", "Start the web server", w.startServer)) parser.Subcommand(argparse.NewSubcommand("stop", "Stop the web server", w.stopServer)) - parser.Parse(args) + parser.Parse(ctx, args) - return parser.InvokeSubcommand() + return parser.InvokeSubcommand(ctx) } diff --git a/cmd/git-bundle-web-server/bundle-server.go b/cmd/git-bundle-web-server/bundle-server.go new file mode 100644 index 0000000..78cb50d --- /dev/null +++ b/cmd/git-bundle-web-server/bundle-server.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "strings" + "sync" + "syscall" + + "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" +) + +type bundleWebServer struct { + logger log.TraceLogger + server *http.Server + serverWaitGroup *sync.WaitGroup + listenAndServeFunc func() error +} + +func NewBundleWebServer(logger log.TraceLogger, + port string, certFile string, keyFile string, +) *bundleWebServer { + bundleServer := &bundleWebServer{ + logger: logger, + serverWaitGroup: &sync.WaitGroup{}, + } + + // Configure the http.Server + mux := http.NewServeMux() + mux.HandleFunc("/", bundleServer.serve) + bundleServer.server = &http.Server{ + Handler: mux, + Addr: ":" + port, + } + + if certFile != "" { + bundleServer.listenAndServeFunc = func() error { return bundleServer.server.ListenAndServeTLS(certFile, keyFile) } + } else { + bundleServer.listenAndServeFunc = func() error { return bundleServer.server.ListenAndServe() } + } + + return bundleServer +} + +func (b *bundleWebServer) parseRoute(ctx context.Context, path string) (string, string, string, error) { + elements := strings.FieldsFunc(path, func(char rune) bool { return char == '/' }) + switch len(elements) { + case 0: + return "", "", "", fmt.Errorf("empty route") + case 1: + return "", "", "", fmt.Errorf("route has owner, but no repo") + case 2: + return elements[0], elements[1], "", nil + case 3: + return elements[0], elements[1], elements[2], nil + default: + return "", "", "", fmt.Errorf("path has depth exceeding three") + } +} + +func (b *bundleWebServer) serve(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + user, err := common.NewUserProvider().CurrentUser() + if err != nil { + return + } + fs := common.NewFileSystem() + path := r.URL.Path + + owner, repo, file, err := b.parseRoute(ctx, path) + if err != nil { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Failed to parse route: %s\n", err) + return + } + + route := owner + "/" + repo + + repos, err := core.GetRepositories(user, fs) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Printf("Failed to load routes\n") + return + } + + repository, contains := repos[route] + if !contains { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Failed to get route out of repos\n") + return + } + + if file == "" { + file = "bundle-list" + } + + fileToServe := repository.WebDir + "/" + file + data, err := os.ReadFile(fileToServe) + if err != nil { + w.WriteHeader(http.StatusNotFound) + fmt.Printf("Failed to read file\n") + return + } + + fmt.Printf("Successfully serving content for %s/%s\n", route, file) + w.Write(data) +} + +func (b *bundleWebServer) StartServerAsync(ctx context.Context) { + // Add to wait group + b.serverWaitGroup.Add(1) + + go func(ctx context.Context) { + defer b.serverWaitGroup.Done() + + // Return error unless it indicates graceful shutdown + err := b.listenAndServeFunc() + if err != nil && err != http.ErrServerClosed { + b.logger.Fatal(ctx, err) + } + }(ctx) + + fmt.Println("Server is running at address " + b.server.Addr) +} + +func (b *bundleWebServer) HandleSignalsAsync(ctx context.Context) { + // Intercept interrupt signals + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + go func(ctx context.Context) { + <-c + fmt.Println("Starting graceful server shutdown...") + b.server.Shutdown(ctx) + }(ctx) +} + +func (b *bundleWebServer) Wait() { + b.serverWaitGroup.Wait() +} diff --git a/cmd/git-bundle-web-server/main.go b/cmd/git-bundle-web-server/main.go index b52ff2d..be4bd78 100644 --- a/cmd/git-bundle-web-server/main.go +++ b/cmd/git-bundle-web-server/main.go @@ -4,146 +4,41 @@ import ( "context" "flag" "fmt" - "log" - "net/http" "os" - "os/signal" - "strings" - "sync" - "syscall" "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" - "github.com/github/git-bundle-server/internal/common" - "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/log" ) -func parseRoute(path string) (string, string, string, error) { - elements := strings.FieldsFunc(path, func(char rune) bool { return char == '/' }) - switch len(elements) { - case 0: - return "", "", "", fmt.Errorf("empty route") - case 1: - return "", "", "", fmt.Errorf("route has owner, but no repo") - case 2: - return elements[0], elements[1], "", nil - case 3: - return elements[0], elements[1], elements[2], nil - default: - return "", "", "", fmt.Errorf("path has depth exceeding three") - } -} - -func serve(w http.ResponseWriter, r *http.Request) { - user, err := common.NewUserProvider().CurrentUser() - if err != nil { - return - } - fs := common.NewFileSystem() - path := r.URL.Path - - owner, repo, file, err := parseRoute(path) - if err != nil { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Failed to parse route: %s\n", err) - return - } - - route := owner + "/" + repo - - repos, err := core.GetRepositories(user, fs) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - fmt.Printf("Failed to load routes\n") - return - } - - repository, contains := repos[route] - if !contains { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Failed to get route out of repos\n") - return - } - - if file == "" { - file = "bundle-list" - } - - fileToServe := repository.WebDir + "/" + file - data, err := os.ReadFile(fileToServe) - if err != nil { - w.WriteHeader(http.StatusNotFound) - fmt.Printf("Failed to read file\n") - return - } +func main() { + log.WithTraceLogger(context.Background(), func(ctx context.Context, logger log.TraceLogger) { + parser := argparse.NewArgParser(logger, "git-bundle-web-server [--port ] [--cert --key ]") + flags, validate := utils.WebServerFlags(parser) + flags.VisitAll(func(f *flag.Flag) { + parser.Var(f.Value, f.Name, f.Usage) + }) - fmt.Printf("Successfully serving content for %s/%s\n", route, file) - w.Write(data) -} + parser.Parse(ctx, os.Args[1:]) + validate(ctx) -func startServer(server *http.Server, - cert string, key string, - serverWaitGroup *sync.WaitGroup, -) { - // Add to wait group - serverWaitGroup.Add(1) + // Get the flag values + port := utils.GetFlagValue[string](parser, "port") + cert := utils.GetFlagValue[string](parser, "cert") + key := utils.GetFlagValue[string](parser, "key") - go func() { - defer serverWaitGroup.Done() + // Configure the server + bundleServer := NewBundleWebServer(logger, port, cert, key) - // Return error unless it indicates graceful shutdown - var err error - if cert != "" { - err = server.ListenAndServeTLS(cert, key) - } else { - err = server.ListenAndServe() - } + // Start the server asynchronously + bundleServer.StartServerAsync(ctx) - if err != nil && err != http.ErrServerClosed { - log.Fatal(err) - } - }() + // Intercept interrupt signals + bundleServer.HandleSignalsAsync(ctx) - fmt.Println("Server is running at address " + server.Addr) -} + // Wait for server to shut down + bundleServer.Wait() -func main() { - parser := argparse.NewArgParser("git-bundle-web-server [--port ] [--cert --key ]") - flags, validate := utils.WebServerFlags(parser) - flags.VisitAll(func(f *flag.Flag) { - parser.Var(f.Value, f.Name, f.Usage) + fmt.Println("Shutdown complete") }) - parser.Parse(os.Args[1:]) - validate() - - // Get the flag values - port := utils.GetFlagValue[string](parser, "port") - cert := utils.GetFlagValue[string](parser, "cert") - key := utils.GetFlagValue[string](parser, "key") - - // Configure the server - mux := http.NewServeMux() - mux.HandleFunc("/", serve) - server := &http.Server{ - Handler: mux, - Addr: ":" + port, - } - serverWaitGroup := &sync.WaitGroup{} - - // Start the server asynchronously - startServer(server, cert, key, serverWaitGroup) - - // Intercept interrupt signals - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - go func() { - <-c - fmt.Println("Starting graceful server shutdown...") - server.Shutdown(context.Background()) - }() - - // Wait for server to shut down - serverWaitGroup.Wait() - - fmt.Println("Shutdown complete") } diff --git a/cmd/utils/common-args.go b/cmd/utils/common-args.go index 9f6af20..2331ae8 100644 --- a/cmd/utils/common-args.go +++ b/cmd/utils/common-args.go @@ -1,6 +1,7 @@ package utils import ( + "context" "flag" "fmt" "strconv" @@ -14,7 +15,7 @@ import ( // functions we want to call from the parser. type argParser interface { Lookup(name string) *flag.Flag - Usage(errFmt string, args ...any) + Usage(ctx context.Context, errFmt string, args ...any) } func GetFlagValue[T any](parser argParser, name string) T { @@ -38,20 +39,20 @@ func GetFlagValue[T any](parser argParser, name string) T { // Sets of flags shared between multiple commands/programs -func WebServerFlags(parser argParser) (*flag.FlagSet, func()) { +func WebServerFlags(parser argParser) (*flag.FlagSet, func(context.Context)) { f := flag.NewFlagSet("", flag.ContinueOnError) port := f.String("port", "8080", "The port on which the server should be hosted") cert := f.String("cert", "", "The path to the X.509 SSL certificate file to use in securely hosting the server") key := f.String("key", "", "The path to the certificate's private key") // Function to call for additional arg validation (may exit with 'Usage()') - validationFunc := func() { + validationFunc := func(ctx context.Context) { p, err := strconv.Atoi(*port) if err != nil || p < 0 || p > 65535 { - parser.Usage("Invalid port '%s'.", *port) + parser.Usage(ctx, "Invalid port '%s'.", *port) } if (*cert == "") != (*key == "") { - parser.Usage("Both '--cert' and '--key' are needed to specify SSL configuration.") + parser.Usage(ctx, "Both '--cert' and '--key' are needed to specify SSL configuration.") } } diff --git a/go.mod b/go.mod index e0bc362..82bb106 100644 --- a/go.mod +++ b/go.mod @@ -2,11 +2,17 @@ module github.com/github/git-bundle-server go 1.19 -require github.com/stretchr/testify v1.8.1 +require ( + github.com/google/uuid v1.3.0 + github.com/stretchr/testify v1.8.1 + go.uber.org/zap v1.24.0 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.0 // indirect + go.uber.org/atomic v1.10.0 // indirect + go.uber.org/multierr v1.9.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 231c467..89383a5 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -11,6 +15,13 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/argparse/argparse.go b/internal/argparse/argparse.go index 3048469..9fe8f2f 100644 --- a/internal/argparse/argparse.go +++ b/internal/argparse/argparse.go @@ -1,12 +1,17 @@ package argparse import ( + "context" "flag" "fmt" - "os" "strings" + + "github.com/github/git-bundle-server/internal/log" ) +// For consistency with 'flag', use 2 as the usage-related error code +const usageExitCode int = 2 + type positionalArg struct { name string description string @@ -26,17 +31,19 @@ type argParser struct { // Post-parsing selectedSubcommand Subcommand + logger log.TraceLogger flag.FlagSet } -func NewArgParser(usageString string) *argParser { - flagSet := flag.NewFlagSet("", flag.ExitOnError) +func NewArgParser(logger log.TraceLogger, usageString string) *argParser { + flagSet := flag.NewFlagSet("", flag.ContinueOnError) a := &argParser{ isTopLevel: false, parsed: false, argOffset: 0, subcommands: make(map[string]Subcommand), + logger: logger, FlagSet: *flagSet, } @@ -114,7 +121,7 @@ func (a *argParser) PositionalList(name string, description string) *[]string { return arg } -func (a *argParser) Parse(args []string) { +func (a *argParser) Parse(ctx context.Context, args []string) { if a.parsed { // Do nothing if we've already parsed args return @@ -136,18 +143,21 @@ func (a *argParser) Parse(args []string) { err := a.FlagSet.Parse(args) if err != nil { - panic("argParser FlagSet error handling should be 'ExitOnError', but error encountered") + // The error was already printed (via a.FlagSet.Usage()), so we + // just need to exit + a.logger.Error(ctx, err) + a.logger.Exit(ctx, usageExitCode) } if len(a.subcommands) > 0 { // Parse subcommand, if applicable if a.FlagSet.NArg() == 0 { - a.Usage("Please specify a subcommand") + a.Usage(ctx, "Please specify a subcommand") } subcommand, exists := a.subcommands[a.FlagSet.Arg(0)] if !exists { - a.Usage("Invalid subcommand '%s'", a.FlagSet.Arg(0)) + a.Usage(ctx, "Invalid subcommand '%s'", a.FlagSet.Arg(0)) } else { a.selectedSubcommand = subcommand a.argOffset++ @@ -177,7 +187,7 @@ func (a *argParser) Parse(args []string) { if a.NArg() != 0 { // If not using subcommands, all args should be accounted for // Exit with usage if not - a.Usage("Unused arguments specified: %s", strings.Join(a.Args(), " ")) + a.Usage(ctx, "Unused arguments specified: %s", strings.Join(a.Args(), " ")) } } @@ -200,18 +210,22 @@ func (a *argParser) NArg() int { } } -func (a *argParser) InvokeSubcommand() error { +func (a *argParser) InvokeSubcommand(ctx context.Context) error { if !a.parsed || a.selectedSubcommand == nil { panic("subcommand has not been parsed") } - return a.selectedSubcommand.Run(a.Args()) + if a.isTopLevel { + a.logger.LogCommand(ctx, a.selectedSubcommand.Name()) + } + + return a.selectedSubcommand.Run(ctx, a.Args()) } -func (a *argParser) Usage(errFmt string, args ...any) { +func (a *argParser) Usage(ctx context.Context, errFmt string, args ...any) { fmt.Fprintf(a.FlagSet.Output(), errFmt+"\n", args...) a.FlagSet.Usage() - // Exit with error code 2 to match flag.Parse() behavior - os.Exit(2) + a.logger.Errorf(ctx, errFmt, args...) + a.logger.Exit(ctx, usageExitCode) } diff --git a/internal/argparse/subcommand.go b/internal/argparse/subcommand.go index 5e8b038..77f7b8c 100644 --- a/internal/argparse/subcommand.go +++ b/internal/argparse/subcommand.go @@ -1,21 +1,23 @@ package argparse +import "context" + type Subcommand interface { Name() string Description() string - Run(args []string) error + Run(ctx context.Context, args []string) error } type genericSubcommand struct { nameStr string descriptionStr string - runFunc func([]string) error + runFunc func(context.Context, []string) error } func NewSubcommand( name string, description string, - runFunc func([]string) error, + runFunc func(context.Context, []string) error, ) *genericSubcommand { return &genericSubcommand{ nameStr: name, @@ -32,6 +34,6 @@ func (s *genericSubcommand) Description() string { return s.descriptionStr } -func (s *genericSubcommand) Run(args []string) error { - return s.runFunc(args) +func (s *genericSubcommand) Run(ctx context.Context, args []string) error { + return s.runFunc(ctx, args) } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index c3a2c0b..c1bd31c 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -1,10 +1,12 @@ package daemon import ( + "context" "fmt" "runtime" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" ) type DaemonConfig struct { @@ -15,16 +17,17 @@ type DaemonConfig struct { } type DaemonProvider interface { - Create(config *DaemonConfig, force bool) error + Create(ctx context.Context, config *DaemonConfig, force bool) error - Start(label string) error + Start(ctx context.Context, label string) error - Stop(label string) error + Stop(ctx context.Context, label string) error - Remove(label string) error + Remove(ctx context.Context, label string) error } func NewDaemonProvider( + l log.TraceLogger, u common.UserProvider, c common.CommandExecutor, fs common.FileSystem, @@ -32,10 +35,10 @@ func NewDaemonProvider( switch thisOs := runtime.GOOS; thisOs { case "linux": // Use systemd/systemctl - return NewSystemdProvider(u, c, fs), nil + return NewSystemdProvider(l, u, c, fs), nil case "darwin": // Use launchd/launchctl - return NewLaunchdProvider(u, c, fs), nil + return NewLaunchdProvider(l, u, c, fs), nil default: return nil, fmt.Errorf("cannot configure daemon handler for OS '%s'", thisOs) } diff --git a/internal/daemon/launchd.go b/internal/daemon/launchd.go index 235c86e..30e307a 100644 --- a/internal/daemon/launchd.go +++ b/internal/daemon/launchd.go @@ -2,11 +2,13 @@ package daemon import ( "bytes" + "context" "encoding/xml" "fmt" "path/filepath" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" "github.com/github/git-bundle-server/internal/utils" ) @@ -90,28 +92,31 @@ func (c *launchdConfig) toPlist() *plist { } type launchd struct { + logger log.TraceLogger user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } func NewLaunchdProvider( + l log.TraceLogger, u common.UserProvider, c common.CommandExecutor, fs common.FileSystem, ) DaemonProvider { return &launchd{ + logger: l, user: u, cmdExec: c, fileSystem: fs, } } -func (l *launchd) isBootstrapped(serviceTarget string) (bool, error) { +func (l *launchd) isBootstrapped(ctx context.Context, serviceTarget string) (bool, error) { // run 'launchctl print' on given service target to see if it exists exitCode, err := l.cmdExec.Run("launchctl", "print", serviceTarget) if err != nil { - return false, err + return false, l.logger.Error(ctx, err) } if exitCode == 0 { @@ -119,30 +124,30 @@ func (l *launchd) isBootstrapped(serviceTarget string) (bool, error) { } else if exitCode == LaunchdServiceNotFoundErrorCode { return false, nil } else { - return false, fmt.Errorf("could not determine if service '%s' is bootstrapped: "+ + return false, l.logger.Errorf(ctx, "could not determine if service '%s' is bootstrapped: "+ "'launchctl print' exited with status '%d'", serviceTarget, exitCode) } } -func (l *launchd) bootstrapFile(domain string, filename string) error { +func (l *launchd) bootstrapFile(ctx context.Context, domain string, filename string) error { // run 'launchctl bootstrap' on given domain & file exitCode, err := l.cmdExec.Run("launchctl", "bootstrap", domain, filename) if err != nil { - return err + return l.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'launchctl bootstrap' exited with status %d", exitCode) + return l.logger.Errorf(ctx, "'launchctl bootstrap' exited with status %d", exitCode) } return nil } -func (l *launchd) bootout(serviceTarget string) (bool, error) { +func (l *launchd) bootout(ctx context.Context, serviceTarget string) (bool, error) { // run 'launchctl bootout' on given service target exitCode, err := l.cmdExec.Run("launchctl", "bootout", serviceTarget) if err != nil { - return false, err + return false, l.logger.Error(ctx, err) } if exitCode == 0 { @@ -150,11 +155,11 @@ func (l *launchd) bootout(serviceTarget string) (bool, error) { } else if exitCode == LaunchdNoSuchProcessErrorCode { return false, nil } else { - return false, fmt.Errorf("'launchctl bootout' failed with status %d", exitCode) + return false, l.logger.Errorf(ctx, "'launchctl bootout' failed with status %d", exitCode) } } -func (l *launchd) Create(config *DaemonConfig, force bool) error { +func (l *launchd) Create(ctx context.Context, config *DaemonConfig, force bool) error { // Add launchd-specific config lConfig := &launchdConfig{ DaemonConfig: *config, @@ -171,27 +176,27 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { encoder.Indent("", " ") err := encoder.Encode(lConfig.toPlist()) if err != nil { - return fmt.Errorf("could not encode plist: %w", err) + return l.logger.Errorf(ctx, "could not encode plist: %w", err) } // Check the existing file - if it's the same as the new content, do not overwrite user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } filename := filepath.Join(user.HomeDir, "Library", "LaunchAgents", fmt.Sprintf("%s.plist", config.Label)) domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, config.Label) - alreadyLoaded, err := l.isBootstrapped(serviceTarget) + alreadyLoaded, err := l.isBootstrapped(ctx, serviceTarget) if err != nil { - return err + return l.logger.Error(ctx, err) } fileExists, err := l.fileSystem.FileExists(filename) if err != nil { - return fmt.Errorf("could not determine whether plist '%s' exists: %w", filename, err) + return l.logger.Errorf(ctx, "could not determine whether plist '%s' exists: %w", filename, err) } // If not forcing re-configuration & the service configuration is valid, @@ -202,9 +207,9 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { // Unload the service so we can reconfigure & reload if alreadyLoaded { - _, err = l.bootout(serviceTarget) + _, err = l.bootout(ctx, serviceTarget) if err != nil { - return fmt.Errorf("could not bootout daemon process '%s': %w", config.Label, err) + return l.logger.Errorf(ctx, "could not bootout daemon process '%s': %w", config.Label, err) } } @@ -213,79 +218,79 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error { // TODO: only overwrite file if file contents have changed err = l.fileSystem.WriteFile(filename, newPlist.Bytes()) if err != nil { - return fmt.Errorf("unable to write plist file: %w", err) + return l.logger.Errorf(ctx, "unable to write plist file: %w", err) } } - err = l.bootstrapFile(domainTarget, filename) + err = l.bootstrapFile(ctx, domainTarget, filename) if err != nil { - return fmt.Errorf("could not bootstrap daemon process '%s': %w", config.Label, err) + return l.logger.Errorf(ctx, "could not bootstrap daemon process '%s': %w", config.Label, err) } return nil } -func (l *launchd) Start(label string) error { +func (l *launchd) Start(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, label) exitCode, err := l.cmdExec.Run("launchctl", "kickstart", serviceTarget) if err != nil { - return err + return l.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'launchctl kickstart' exited with status %d", exitCode) + return l.logger.Errorf(ctx, "'launchctl kickstart' exited with status %d", exitCode) } return nil } -func (l *launchd) Stop(label string) error { +func (l *launchd) Stop(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, label) exitCode, err := l.cmdExec.Run("launchctl", "kill", "SIGINT", serviceTarget) if err != nil { - return err + return l.logger.Error(ctx, err) } // Don't throw an error if the service hasn't been bootstrapped if exitCode != 0 && exitCode != LaunchdServiceNotFoundErrorCode && exitCode != LaunchdNoSuchProcessErrorCode { - return fmt.Errorf("'launchctl kill' exited with status %d", exitCode) + return l.logger.Errorf(ctx, "'launchctl kill' exited with status %d", exitCode) } return nil } -func (l *launchd) Remove(label string) error { +func (l *launchd) Remove(ctx context.Context, label string) error { user, err := l.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return l.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } filename := filepath.Join(user.HomeDir, "Library", "LaunchAgents", fmt.Sprintf("%s.plist", label)) domainTarget := fmt.Sprintf(domainFormat, user.Uid) serviceTarget := fmt.Sprintf("%s/%s", domainTarget, label) - _, err = l.bootout(serviceTarget) + _, err = l.bootout(ctx, serviceTarget) if err != nil { - return fmt.Errorf("could not remove daemon process '%s': %w", label, err) + return l.logger.Errorf(ctx, "could not remove daemon process '%s': %w", label, err) } _, err = l.fileSystem.DeleteFile(filename) if err != nil { - return fmt.Errorf("could not delete launchd plist: %w", err) + return l.logger.Errorf(ctx, "could not delete launchd plist: %w", err) } return nil diff --git a/internal/daemon/launchd_test.go b/internal/daemon/launchd_test.go index 59d1084..b190dfc 100644 --- a/internal/daemon/launchd_test.go +++ b/internal/daemon/launchd_test.go @@ -1,6 +1,7 @@ package daemon_test import ( + "context" "encoding/xml" "fmt" "os/user" @@ -225,6 +226,7 @@ var launchdCreatePlistTests = []struct { func TestLaunchd_Create(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -237,7 +239,9 @@ func TestLaunchd_Create(t *testing.T) { testFileSystem := &MockFileSystem{} - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, testFileSystem) + ctx := context.Background() + + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) // Verify launchd commands called for _, tt := range launchdCreateBehaviorTests { @@ -276,7 +280,7 @@ func TestLaunchd_Create(t *testing.T) { } // Run "Create" - err := launchd.Create(tt.config, force) + err := launchd.Create(ctx, tt.config, force) // Assert on expected values if tt.expectErr { @@ -325,7 +329,7 @@ func TestLaunchd_Create(t *testing.T) { }), ).Return(nil).Once() - err := launchd.Create(tt.config, false) + err := launchd.Create(ctx, tt.config, false) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor, testFileSystem) @@ -353,6 +357,7 @@ func TestLaunchd_Create(t *testing.T) { func TestLaunchd_Start(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -362,7 +367,9 @@ func TestLaunchd_Start(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, nil) + ctx := context.Background() + + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, nil) // Test #1: launchctl succeeds t.Run("Calls correct launchctl command", func(t *testing.T) { @@ -371,7 +378,7 @@ func TestLaunchd_Start(t *testing.T) { []string{"kickstart", fmt.Sprintf("user/123/%s", basicDaemonConfig.Label)}, ).Return(0, nil).Once() - err := launchd.Start(basicDaemonConfig.Label) + err := launchd.Start(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -386,7 +393,7 @@ func TestLaunchd_Start(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(1, nil).Once() - err := launchd.Start(basicDaemonConfig.Label) + err := launchd.Start(ctx, basicDaemonConfig.Label) assert.NotNil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -432,6 +439,7 @@ var launchdStopTests = []struct { func TestLaunchd_Stop(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -441,7 +449,9 @@ func TestLaunchd_Stop(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, nil) + ctx := context.Background() + + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, nil) for _, tt := range launchdStopTests { t.Run(tt.title, func(t *testing.T) { @@ -454,7 +464,7 @@ func TestLaunchd_Stop(t *testing.T) { } // Call function - err := launchd.Stop(tt.label) + err := launchd.Stop(ctx, tt.label) mock.AssertExpectationsForObjects(t, testCommandExecutor) if tt.expectErr { assert.NotNil(t, err) @@ -520,6 +530,7 @@ var launchdRemoveTests = []struct { func TestLaunchd_Remove(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -531,7 +542,9 @@ func TestLaunchd_Remove(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} testFileSystem := &MockFileSystem{} - launchd := daemon.NewLaunchdProvider(testUserProvider, testCommandExecutor, testFileSystem) + ctx := context.Background() + + launchd := daemon.NewLaunchdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range launchdRemoveTests { t.Run(tt.title, func(t *testing.T) { @@ -552,7 +565,7 @@ func TestLaunchd_Remove(t *testing.T) { } // Call function - err := launchd.Remove(tt.label) + err := launchd.Remove(ctx, tt.label) mock.AssertExpectationsForObjects(t, testCommandExecutor) if tt.expectErr { assert.NotNil(t, err) diff --git a/internal/daemon/systemd.go b/internal/daemon/systemd.go index 545277e..6e97465 100644 --- a/internal/daemon/systemd.go +++ b/internal/daemon/systemd.go @@ -2,12 +2,14 @@ package daemon import ( "bytes" + "context" "fmt" "path/filepath" "strings" "text/template" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" ) const serviceTemplate string = `[Unit] @@ -21,40 +23,43 @@ ExecStart={{sq_escape .Program}}{{range .Arguments}} {{sq_escape .}}{{end}} const SystemdUnitNotInstalledErrorCode int = 5 type systemd struct { + logger log.TraceLogger user common.UserProvider cmdExec common.CommandExecutor fileSystem common.FileSystem } func NewSystemdProvider( + l log.TraceLogger, u common.UserProvider, c common.CommandExecutor, fs common.FileSystem, ) DaemonProvider { return &systemd{ + logger: l, user: u, cmdExec: c, fileSystem: fs, } } -func (s *systemd) reloadDaemon() error { +func (s *systemd) reloadDaemon(ctx context.Context) error { exitCode, err := s.cmdExec.Run("systemctl", "--user", "daemon-reload") if err != nil { - return err + return s.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'systemctl --user daemon-reload' exited with status %d", exitCode) + return s.logger.Errorf(ctx, "'systemctl --user daemon-reload' exited with status %d", exitCode) } return nil } -func (s *systemd) Create(config *DaemonConfig, force bool) error { +func (s *systemd) Create(ctx context.Context, config *DaemonConfig, force bool) error { user, err := s.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for systemd service: %w", err) + return s.logger.Errorf(ctx, "could not get current user for systemd service: %w", err) } // Generate the configuration @@ -65,7 +70,7 @@ func (s *systemd) Create(config *DaemonConfig, force bool) error { }, }).Parse(serviceTemplate) if err != nil { - return fmt.Errorf("unable to generate systemd configuration: %w", err) + return s.logger.Errorf(ctx, "unable to generate systemd configuration: %w", err) } t.Execute(&newServiceUnit, config) @@ -74,7 +79,7 @@ func (s *systemd) Create(config *DaemonConfig, force bool) error { // Check whether the file exists fileExists, err := s.fileSystem.FileExists(filename) if err != nil { - return fmt.Errorf("could not determine whether service unit '%s' exists: %w", config.Label, err) + return s.logger.Errorf(ctx, "could not determine whether service unit '%s' exists: %w", config.Label, err) } if !force && fileExists { @@ -85,62 +90,62 @@ func (s *systemd) Create(config *DaemonConfig, force bool) error { // Otherwise, write the new file err = s.fileSystem.WriteFile(filename, newServiceUnit.Bytes()) if err != nil { - return fmt.Errorf("unable to write service unit: %w", err) + return s.logger.Errorf(ctx, "unable to write service unit: %w", err) } // Reload the user-scoped service units after adding - err = s.reloadDaemon() + err = s.reloadDaemon(ctx) if err != nil { - return err + return s.logger.Error(ctx, err) } return nil } -func (s *systemd) Start(label string) error { +func (s *systemd) Start(ctx context.Context, label string) error { // TODO: warn user if already running exitCode, err := s.cmdExec.Run("systemctl", "--user", "start", label) if err != nil { - return err + return s.logger.Error(ctx, err) } if exitCode != 0 { - return fmt.Errorf("'systemctl stop' exited with status %d", exitCode) + return s.logger.Errorf(ctx, "'systemctl stop' exited with status %d", exitCode) } return nil } -func (s *systemd) Stop(label string) error { +func (s *systemd) Stop(ctx context.Context, label string) error { // TODO: warn user if already stopped exitCode, err := s.cmdExec.Run("systemctl", "--user", "stop", label) if err != nil { - return err + return s.logger.Error(ctx, err) } if exitCode != 0 && exitCode != SystemdUnitNotInstalledErrorCode { - return fmt.Errorf("'systemctl stop' exited with status %d", exitCode) + return s.logger.Errorf(ctx, "'systemctl stop' exited with status %d", exitCode) } return nil } -func (s *systemd) Remove(label string) error { +func (s *systemd) Remove(ctx context.Context, label string) error { user, err := s.user.CurrentUser() if err != nil { - return fmt.Errorf("could not get current user for launchd service: %w", err) + return s.logger.Errorf(ctx, "could not get current user for launchd service: %w", err) } filename := filepath.Join(user.HomeDir, ".config", "systemd", "user", fmt.Sprintf("%s.service", label)) _, err = s.fileSystem.DeleteFile(filename) if err != nil { - return fmt.Errorf("could not delete service unit: %w", err) + return s.logger.Errorf(ctx, "could not delete service unit: %w", err) } // Reload the user-scoped service units after removing - err = s.reloadDaemon() + err = s.reloadDaemon(ctx) if err != nil { - return err + return s.logger.Error(ctx, err) } return nil diff --git a/internal/daemon/systemd_test.go b/internal/daemon/systemd_test.go index 0ecc11a..e41923c 100644 --- a/internal/daemon/systemd_test.go +++ b/internal/daemon/systemd_test.go @@ -1,6 +1,7 @@ package daemon_test import ( + "context" "fmt" "os/user" "path/filepath" @@ -125,6 +126,7 @@ var systemdCreateServiceUnitTests = []struct { func TestSystemd_Create(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -137,7 +139,9 @@ func TestSystemd_Create(t *testing.T) { testFileSystem := &MockFileSystem{} - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, testFileSystem) + ctx := context.Background() + + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdCreateBehaviorTests { forceArg := tt.force.ToBoolList() @@ -163,7 +167,7 @@ func TestSystemd_Create(t *testing.T) { } // Run "Create" - err := systemd.Create(tt.config, force) + err := systemd.Create(ctx, tt.config, force) // Assert on expected values if tt.expectErr { @@ -208,7 +212,7 @@ func TestSystemd_Create(t *testing.T) { }), ).Return(nil).Once() - err := systemd.Create(tt.config, false) + err := systemd.Create(ctx, tt.config, false) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor, testFileSystem) @@ -233,6 +237,7 @@ func TestSystemd_Create(t *testing.T) { func TestSystemd_Start(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -243,7 +248,9 @@ func TestSystemd_Start(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, nil) + ctx := context.Background() + + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds t.Run("Calls correct systemctl command", func(t *testing.T) { @@ -252,7 +259,7 @@ func TestSystemd_Start(t *testing.T) { []string{"--user", "start", basicDaemonConfig.Label}, ).Return(0, nil).Once() - err := systemd.Start(basicDaemonConfig.Label) + err := systemd.Start(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -267,7 +274,7 @@ func TestSystemd_Start(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(1, nil).Once() - err := systemd.Start(basicDaemonConfig.Label) + err := systemd.Start(ctx, basicDaemonConfig.Label) assert.NotNil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -275,6 +282,7 @@ func TestSystemd_Start(t *testing.T) { func TestSystemd_Stop(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -285,7 +293,9 @@ func TestSystemd_Stop(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, nil) + ctx := context.Background() + + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, nil) // Test #1: systemctl succeeds t.Run("Calls correct systemctl command", func(t *testing.T) { @@ -294,7 +304,7 @@ func TestSystemd_Stop(t *testing.T) { []string{"--user", "stop", basicDaemonConfig.Label}, ).Return(0, nil).Once() - err := systemd.Stop(basicDaemonConfig.Label) + err := systemd.Stop(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -309,7 +319,7 @@ func TestSystemd_Stop(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(1, nil).Once() - err := systemd.Stop(basicDaemonConfig.Label) + err := systemd.Stop(ctx, basicDaemonConfig.Label) assert.NotNil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -324,7 +334,7 @@ func TestSystemd_Stop(t *testing.T) { mock.AnythingOfType("[]string"), ).Return(daemon.SystemdUnitNotInstalledErrorCode, nil).Once() - err := systemd.Stop(basicDaemonConfig.Label) + err := systemd.Stop(ctx, basicDaemonConfig.Label) assert.Nil(t, err) mock.AssertExpectationsForObjects(t, testCommandExecutor) }) @@ -368,6 +378,7 @@ var systemdRemoveTests = []struct { func TestSystemd_Remove(t *testing.T) { // Set up mocks + testLogger := &MockTraceLogger{} testUser := &user.User{ Uid: "123", Username: "testuser", @@ -379,7 +390,9 @@ func TestSystemd_Remove(t *testing.T) { testCommandExecutor := &MockCommandExecutor{} testFileSystem := &MockFileSystem{} - systemd := daemon.NewSystemdProvider(testUserProvider, testCommandExecutor, testFileSystem) + ctx := context.Background() + + systemd := daemon.NewSystemdProvider(testLogger, testUserProvider, testCommandExecutor, testFileSystem) for _, tt := range systemdRemoveTests { t.Run(tt.title, func(t *testing.T) { @@ -400,7 +413,7 @@ func TestSystemd_Remove(t *testing.T) { } // Call function - err := systemd.Remove(tt.label) + err := systemd.Remove(ctx, tt.label) mock.AssertExpectationsForObjects(t, testCommandExecutor) if tt.expectErr { assert.NotNil(t, err) diff --git a/internal/log/logger.go b/internal/log/logger.go new file mode 100644 index 0000000..41ff3c2 --- /dev/null +++ b/internal/log/logger.go @@ -0,0 +1,53 @@ +package log + +import ( + "context" + "fmt" + "os" + "runtime/debug" +) + +// Type alias used to keep track of whether an error was already logged deeper +// in the call stack. +type loggedError error + +type TraceLogger interface { + LogCommand(ctx context.Context, commandName string) context.Context + Error(ctx context.Context, err error) error + Errorf(ctx context.Context, format string, a ...any) error + Exit(ctx context.Context, exitCode int) + Fatal(ctx context.Context, err error) + Fatalf(ctx context.Context, format string, a ...any) +} + +type traceLoggerInternal interface { + // Internal setup/teardown functions + logStart(ctx context.Context) context.Context + logExit(ctx context.Context, exitCode int) + + TraceLogger +} + +func WithTraceLogger( + ctx context.Context, + mainFunc func(context.Context, TraceLogger), +) { + logger := NewTrace2() + + // Set up the program-level context + ctx = logger.logStart(ctx) + defer func() { + if panicInfo := recover(); panicInfo != nil { + // Panicking - log, print panic info, then exit + logger.logExit(ctx, 1) + os.Stderr.WriteString(fmt.Sprintf("panic: %s\n\n", panicInfo)) + debug.PrintStack() + os.Exit(1) + } else { + // Just log the exit (but don't os.Exit()) so we can exit normally + logger.logExit(ctx, 0) + } + }() + + mainFunc(ctx, logger) +} diff --git a/internal/log/trace2.go b/internal/log/trace2.go new file mode 100644 index 0000000..789dd55 --- /dev/null +++ b/internal/log/trace2.go @@ -0,0 +1,222 @@ +package log + +import ( + "context" + "fmt" + "os" + "path" + "path/filepath" + "runtime" + "strconv" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Trace2 environment variables +const ( + // TODO: handle GIT_TRACE2 by adding a separate output config (see zapcore + // "AdvancedConfiguration" example: + // https://pkg.go.dev/go.uber.org/zap#example-package-AdvancedConfiguration) + trace2Event string = "GIT_TRACE2_EVENT" +) + +// Global start time +var globalStart = time.Now().UTC() + +const trace2TimeFormat string = "2006-01-02T15:04:05.999999Z" + +type ctxKey int + +const ( + sidId ctxKey = iota +) + +type Trace2 struct { + logger *zap.Logger +} + +func getTrace2OutputPaths(envKey string) []string { + tr2Output := os.Getenv(envKey) + + // Configure the output + if tr2, err := strconv.Atoi(tr2Output); err == nil { + // Handle numeric values + if tr2 == 1 { + return []string{"stderr"} + } + // TODO: handle file handles 2-9 and unix sockets + } else if tr2Output != "" { + // Assume we received a path + fileInfo, err := os.Stat(tr2Output) + if err == nil && fileInfo.IsDir() { + // If the path is an existing directory, generate a filename + return []string{ + filepath.Join(tr2Output, fmt.Sprintf("trace2_%s.txt", globalStart.Format(trace2TimeFormat))), + } + } else { + // Create leading directories + parentDir := path.Dir(tr2Output) + os.MkdirAll(parentDir, 0o755) + return []string{tr2Output} + } + } + + return []string{} +} + +func createTrace2ZapLogger() *zap.Logger { + loggerConfig := zap.NewProductionConfig() + + // Configure the output for GIT_TRACE2_EVENT + loggerConfig.OutputPaths = getTrace2OutputPaths(trace2Event) + loggerConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + + // Encode UTC time + loggerConfig.EncoderConfig.TimeKey = "time" + loggerConfig.EncoderConfig.EncodeTime = zapcore.TimeEncoder( + func(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.Format(trace2TimeFormat)) + }, + ) + + // Re-purpose the "message" to represent the (always-present) "event" key + loggerConfig.EncoderConfig.MessageKey = "event" + + // Don't print the log level + loggerConfig.EncoderConfig.LevelKey = "" + + // Disable caller info, we'll customize those fields manually + logger, _ := loggerConfig.Build(zap.WithCaller(false)) + return logger +} + +func NewTrace2() traceLoggerInternal { + return &Trace2{ + logger: createTrace2ZapLogger(), + } +} + +type fieldList []zap.Field + +func (l fieldList) withTime() fieldList { + return append(l, zap.Float64("t_abs", time.Since(globalStart).Seconds())) +} + +func (l fieldList) with(f ...zap.Field) fieldList { + return append(l, f...) +} + +func (t *Trace2) sharedFields(ctx context.Context) (context.Context, fieldList) { + fields := fieldList{} + + // Get the session ID + var sid uuid.UUID + haveSid := false + sidAny := ctx.Value(sidId) + if sidAny != nil { + sid, haveSid = sidAny.(uuid.UUID) + } + if !haveSid { + sid = uuid.New() + ctx = context.WithValue(ctx, sidId, sid) + } + fields = append(fields, zap.String("sid", sid.String())) + + // Hardcode the thread to "main" because Go doesn't like to share its + // internal info about threading. + fields = append(fields, zap.String("thread", "main")) + + // Get the caller of the function in trace2.go + // Skip up two levels: + // 0: this function + // 1: the caller of this function (StartTrace, LogEvent, etc.) + // 2: the function calling this trace2 library + _, fileName, lineNum, ok := runtime.Caller(2) + if ok { + fields = append(fields, + zap.String("file", filepath.Base(fileName)), + zap.Int("line", lineNum), + ) + } + + return ctx, fields +} + +func (t *Trace2) logStart(ctx context.Context) context.Context { + ctx, sharedFields := t.sharedFields(ctx) + + t.logger.Info("start", sharedFields.withTime().with( + zap.Strings("argv", os.Args), + )...) + + return ctx +} + +func (t *Trace2) logExit(ctx context.Context, exitCode int) { + _, sharedFields := t.sharedFields(ctx) + fields := sharedFields.with( + zap.Int("code", exitCode), + ) + t.logger.Info("exit", fields.withTime()...) + t.logger.Info("atexit", fields.withTime()...) + + t.logger.Sync() +} + +func (t *Trace2) LogCommand(ctx context.Context, commandName string) context.Context { + ctx, sharedFields := t.sharedFields(ctx) + + t.logger.Info("cmd_name", sharedFields.with(zap.String("name", commandName))...) + + return ctx +} + +func (t *Trace2) Error(ctx context.Context, err error) error { + // We only want to log the error if it's not already logged deeper in the + // call stack. + if _, ok := err.(loggedError); !ok { + _, sharedFields := t.sharedFields(ctx) + t.logger.Error("error", sharedFields.with( + zap.String("msg", err.Error()), + zap.String("fmt", err.Error()))...) + } + return loggedError(err) +} + +func (t *Trace2) Errorf(ctx context.Context, format string, a ...any) error { + // We only want to log the error if it's not already logged deeper in the + // call stack. + isLogged := false + for _, fmtArg := range a { + if _, ok := fmtArg.(loggedError); ok { + isLogged = true + break + } + } + + err := loggedError(fmt.Errorf(format, a...)) + + if isLogged { + _, sharedFields := t.sharedFields(ctx) + t.logger.Info("error", sharedFields.with( + zap.String("msg", err.Error()), + zap.String("fmt", format))...) + } + return err +} + +func (t *Trace2) Exit(ctx context.Context, exitCode int) { + t.logExit(ctx, exitCode) + os.Exit(exitCode) +} + +func (t *Trace2) Fatal(ctx context.Context, err error) { + t.Exit(ctx, 1) +} + +func (t *Trace2) Fatalf(ctx context.Context, format string, a ...any) { + t.Exit(ctx, 1) +} diff --git a/internal/testhelpers/mocks.go b/internal/testhelpers/mocks.go index 9f8dc90..6481b7e 100644 --- a/internal/testhelpers/mocks.go +++ b/internal/testhelpers/mocks.go @@ -1,11 +1,113 @@ package testhelpers import ( + "context" + "fmt" "os/user" + "runtime" "github.com/stretchr/testify/mock" ) +func methodIsMocked(m *mock.Mock) bool { + // Get the calling method name + pc := make([]uintptr, 1) + n := runtime.Callers(1, pc) + if n == 0 { + // No caller found - fall back on "not mocked" + return false + } + caller := runtime.FuncForPC(pc[0] - 1) + if caller == nil { + // Caller not found - fall back on "not mocked" + return false + } + + for _, call := range m.ExpectedCalls { + if call.Method == caller.Name() { + return true + } + } + + return false +} + +type notMocked struct{} + +var NotMockedValue notMocked = notMocked{} + +func mockWithDefault[T any](args mock.Arguments, index int, defaultValue T) T { + if len(args) <= index { + return defaultValue + } + + mockedValue := args.Get(index) + if _, ok := mockedValue.(notMocked); ok { + return defaultValue + } + + return mockedValue.(T) +} + +type MockTraceLogger struct { + mock.Mock +} + +func (l *MockTraceLogger) Region(ctx context.Context, category string, label string) (context.Context, func()) { + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, category, label) + } + return mockWithDefault(fnArgs, 0, ctx), mockWithDefault(fnArgs, 1, func() {}) +} + +func (l *MockTraceLogger) LogCommand(ctx context.Context, commandName string) context.Context { + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, commandName) + } + return mockWithDefault(fnArgs, 0, ctx) +} + +func (l *MockTraceLogger) Error(ctx context.Context, err error) error { + // Input validation + if err == nil { + panic("err must be nil") + } + + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, err) + } + return mockWithDefault(fnArgs, 0, err) +} + +func (l *MockTraceLogger) Errorf(ctx context.Context, format string, a ...any) error { + fnArgs := mock.Arguments{} + if methodIsMocked(&l.Mock) { + fnArgs = l.Called(ctx, format, a) + } + return mockWithDefault(fnArgs, 0, fmt.Errorf(format, a...)) +} + +func (l *MockTraceLogger) Exit(ctx context.Context, exitCode int) { + if methodIsMocked(&l.Mock) { + l.Called(ctx, exitCode) + } +} + +func (l *MockTraceLogger) Fatal(ctx context.Context, err error) { + if methodIsMocked(&l.Mock) { + l.Called(ctx, err) + } +} + +func (l *MockTraceLogger) Fatalf(ctx context.Context, format string, a ...any) { + if methodIsMocked(&l.Mock) { + l.Called(ctx, format, a) + } +} + type MockUserProvider struct { mock.Mock }