diff --git a/cmd/git-bundle-server/delete.go b/cmd/git-bundle-server/delete.go index 5833341..1d2a6e3 100644 --- a/cmd/git-bundle-server/delete.go +++ b/cmd/git-bundle-server/delete.go @@ -4,18 +4,21 @@ import ( "context" "os" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/log" ) type deleteCmd struct { - logger log.TraceLogger + logger log.TraceLogger + container *utils.DependencyContainer } -func NewDeleteCommand(logger log.TraceLogger) argparse.Subcommand { +func NewDeleteCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &deleteCmd{ - logger: logger, + logger: logger, + container: container, } } @@ -34,12 +37,14 @@ func (d *deleteCmd) Run(ctx context.Context, args []string) error { route := parser.PositionalString("route", "the route to delete") parser.Parse(ctx, args) - repo, err := core.CreateRepository(*route) + repoProvider := utils.GetDependency[core.RepositoryProvider](ctx, d.container) + + repo, err := repoProvider.CreateRepository(ctx, *route) if err != nil { return d.logger.Error(ctx, err) } - err = core.RemoveRoute(*route) + err = repoProvider.RemoveRoute(ctx, *route) if err != nil { return d.logger.Error(ctx, err) } diff --git a/cmd/git-bundle-server/init.go b/cmd/git-bundle-server/init.go index 2a1d3bd..fc75312 100644 --- a/cmd/git-bundle-server/init.go +++ b/cmd/git-bundle-server/init.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/bundles" "github.com/github/git-bundle-server/internal/core" @@ -12,12 +13,14 @@ import ( ) type initCmd struct { - logger log.TraceLogger + logger log.TraceLogger + container *utils.DependencyContainer } -func NewInitCommand(logger log.TraceLogger) argparse.Subcommand { +func NewInitCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &initCmd{ - logger: logger, + logger: logger, + container: container, } } @@ -38,7 +41,10 @@ func (i *initCmd) Run(ctx context.Context, args []string) error { route := parser.PositionalString("route", "the route to host the specified repo") parser.Parse(ctx, args) - repo, err := core.CreateRepository(*route) + repoProvider := utils.GetDependency[core.RepositoryProvider](ctx, i.container) + bundleProvider := utils.GetDependency[bundles.BundleProvider](ctx, i.container) + + repo, err := repoProvider.CreateRepository(ctx, *route) if err != nil { return i.logger.Error(ctx, err) } @@ -60,7 +66,7 @@ func (i *initCmd) Run(ctx context.Context, args []string) error { return i.logger.Errorf(ctx, "failed to fetch latest refs: %w", gitErr) } - bundle := bundles.CreateInitialBundle(repo) + bundle := bundleProvider.CreateInitialBundle(ctx, repo) fmt.Printf("Constructing base bundle file at %s\n", bundle.Filename) written, gitErr := git.CreateBundle(repo.RepoDir, bundle.Filename) @@ -71,8 +77,8 @@ func (i *initCmd) Run(ctx context.Context, args []string) error { return i.logger.Errorf(ctx, "refused to write empty bundle. Is the repo empty?") } - list := bundles.CreateSingletonList(bundle) - listErr := bundles.WriteBundleList(list, repo) + list := bundleProvider.CreateSingletonList(ctx, bundle) + listErr := bundleProvider.WriteBundleList(ctx, list, repo) if listErr != nil { return i.logger.Errorf(ctx, "failed to write bundle list: %w", listErr) } diff --git a/cmd/git-bundle-server/main.go b/cmd/git-bundle-server/main.go index ba6bbf2..e9d9f16 100644 --- a/cmd/git-bundle-server/main.go +++ b/cmd/git-bundle-server/main.go @@ -4,19 +4,22 @@ import ( "context" "os" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/log" ) func all(logger log.TraceLogger) []argparse.Subcommand { + container := utils.BuildGitBundleServerContainer(logger) + return []argparse.Subcommand{ - NewDeleteCommand(logger), - NewInitCommand(logger), - NewStartCommand(logger), - NewStopCommand(logger), - NewUpdateCommand(logger), - NewUpdateAllCommand(logger), - NewWebServerCommand(logger), + NewDeleteCommand(logger, container), + NewInitCommand(logger, container), + NewStartCommand(logger, container), + NewStopCommand(logger, container), + NewUpdateCommand(logger, container), + NewUpdateAllCommand(logger, container), + NewWebServerCommand(logger, container), } } diff --git a/cmd/git-bundle-server/start.go b/cmd/git-bundle-server/start.go index 8dea0e8..31b235b 100644 --- a/cmd/git-bundle-server/start.go +++ b/cmd/git-bundle-server/start.go @@ -4,18 +4,21 @@ import ( "context" "os" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/log" ) type startCmd struct { - logger log.TraceLogger + logger log.TraceLogger + container *utils.DependencyContainer } -func NewStartCommand(logger log.TraceLogger) argparse.Subcommand { +func NewStartCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &startCmd{ - logger: logger, + logger: logger, + container: container, } } @@ -34,8 +37,10 @@ func (s *startCmd) Run(ctx context.Context, args []string) error { route := parser.PositionalString("route", "the route for which bundles should be generated") parser.Parse(ctx, args) + repoProvider := utils.GetDependency[core.RepositoryProvider](ctx, s.container) + // CreateRepository registers the route. - repo, err := core.CreateRepository(*route) + repo, err := repoProvider.CreateRepository(ctx, *route) if err != nil { return s.logger.Error(ctx, err) } diff --git a/cmd/git-bundle-server/stop.go b/cmd/git-bundle-server/stop.go index 3fa591b..37e0beb 100644 --- a/cmd/git-bundle-server/stop.go +++ b/cmd/git-bundle-server/stop.go @@ -3,18 +3,21 @@ package main import ( "context" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/log" ) type stopCmd struct { - logger log.TraceLogger + logger log.TraceLogger + container *utils.DependencyContainer } -func NewStopCommand(logger log.TraceLogger) argparse.Subcommand { +func NewStopCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &stopCmd{ - logger: logger, + logger: logger, + container: container, } } @@ -33,7 +36,9 @@ func (s *stopCmd) Run(ctx context.Context, args []string) error { route := parser.PositionalString("route", "the route for which bundles should stop being generated") parser.Parse(ctx, args) - err := core.RemoveRoute(*route) + repoProvider := utils.GetDependency[core.RepositoryProvider](ctx, s.container) + + err := repoProvider.RemoveRoute(ctx, *route) if err != nil { s.logger.Error(ctx, err) } diff --git a/cmd/git-bundle-server/update-all.go b/cmd/git-bundle-server/update-all.go index 432c39b..51bada7 100644 --- a/cmd/git-bundle-server/update-all.go +++ b/cmd/git-bundle-server/update-all.go @@ -5,19 +5,21 @@ import ( "os" "os/exec" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" - "github.com/github/git-bundle-server/internal/common" "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/log" ) type updateAllCmd struct { - logger log.TraceLogger + logger log.TraceLogger + container *utils.DependencyContainer } -func NewUpdateAllCommand(logger log.TraceLogger) argparse.Subcommand { +func NewUpdateAllCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &updateAllCmd{ - logger: logger, + logger: logger, + container: container, } } @@ -31,23 +33,19 @@ For every configured route, run 'git-bundle-server update '.` } func (u *updateAllCmd) Run(ctx context.Context, args []string) error { - user, err := common.NewUserProvider().CurrentUser() - if err != nil { - return u.logger.Error(ctx, err) - } - fs := common.NewFileSystem() - parser := argparse.NewArgParser(u.logger, "git-bundle-server update-all") parser.Parse(ctx, args) - exe, err := os.Executable() + repoProvider := utils.GetDependency[core.RepositoryProvider](ctx, u.container) + + repos, err := repoProvider.GetRepositories(ctx) if err != nil { - return u.logger.Errorf(ctx, "failed to get path to execuable: %w", err) + return u.logger.Error(ctx, err) } - repos, err := core.GetRepositories(user, fs) + exe, err := os.Executable() if err != nil { - return u.logger.Error(ctx, err) + return u.logger.Errorf(ctx, "failed to get path to execuable: %w", err) } subargs := []string{"update", ""} diff --git a/cmd/git-bundle-server/update.go b/cmd/git-bundle-server/update.go index 13b0060..75e0ea5 100644 --- a/cmd/git-bundle-server/update.go +++ b/cmd/git-bundle-server/update.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + "github.com/github/git-bundle-server/cmd/utils" "github.com/github/git-bundle-server/internal/argparse" "github.com/github/git-bundle-server/internal/bundles" "github.com/github/git-bundle-server/internal/core" @@ -11,12 +12,14 @@ import ( ) type updateCmd struct { - logger log.TraceLogger + logger log.TraceLogger + container *utils.DependencyContainer } -func NewUpdateCommand(logger log.TraceLogger) argparse.Subcommand { +func NewUpdateCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &updateCmd{ - logger: logger, + logger: logger, + container: container, } } @@ -36,18 +39,21 @@ func (u *updateCmd) Run(ctx context.Context, args []string) error { route := parser.PositionalString("route", "the route to update") parser.Parse(ctx, args) - repo, err := core.CreateRepository(*route) + repoProvider := utils.GetDependency[core.RepositoryProvider](ctx, u.container) + bundleProvider := utils.GetDependency[bundles.BundleProvider](ctx, u.container) + + repo, err := repoProvider.CreateRepository(ctx, *route) if err != nil { return u.logger.Error(ctx, err) } - list, err := bundles.GetBundleList(repo) + list, err := bundleProvider.GetBundleList(ctx, repo) if err != nil { return u.logger.Errorf(ctx, "failed to load bundle list: %w", err) } fmt.Printf("Creating new incremental bundle\n") - bundle, err := bundles.CreateIncrementalBundle(repo, list) + bundle, err := bundleProvider.CreateIncrementalBundle(ctx, repo, list) if err != nil { return u.logger.Error(ctx, err) } @@ -60,13 +66,13 @@ func (u *updateCmd) Run(ctx context.Context, args []string) error { list.Bundles[bundle.CreationToken] = *bundle fmt.Printf("Collapsing bundle list\n") - err = bundles.CollapseList(repo, list) + err = bundleProvider.CollapseList(ctx, repo, list) if err != nil { return u.logger.Error(ctx, err) } fmt.Printf("Writing updated bundle list\n") - listErr := bundles.WriteBundleList(list, repo) + listErr := bundleProvider.WriteBundleList(ctx, list, repo) if listErr != nil { return u.logger.Errorf(ctx, "failed to write bundle list: %w", listErr) } diff --git a/cmd/git-bundle-server/web-server.go b/cmd/git-bundle-server/web-server.go index 912089a..0e8e927 100644 --- a/cmd/git-bundle-server/web-server.go +++ b/cmd/git-bundle-server/web-server.go @@ -17,19 +17,14 @@ import ( ) type webServerCmd struct { - logger log.TraceLogger - user common.UserProvider - cmdExec common.CommandExecutor - fileSystem common.FileSystem + logger log.TraceLogger + container *utils.DependencyContainer } -func NewWebServerCommand(logger log.TraceLogger) argparse.Subcommand { - // Create subcommand-specific dependencies +func NewWebServerCommand(logger log.TraceLogger, container *utils.DependencyContainer) argparse.Subcommand { return &webServerCmd{ - logger: logger, - user: common.NewUserProvider(), - cmdExec: common.NewCommandExecutor(), - fileSystem: common.NewFileSystem(), + logger: logger, + container: container, } } @@ -64,7 +59,8 @@ func (w *webServerCmd) getDaemonConfig(ctx context.Context) (*daemon.DaemonConfi } programPath = filepath.Join(exeDir, "git-bundle-web-server") - programExists, err := w.fileSystem.FileExists(programPath) + fileSystem := utils.GetDependency[common.FileSystem](ctx, w.container) + programExists, err := fileSystem.FileExists(programPath) if err != nil { return nil, w.logger.Errorf(ctx, "could not determine whether path to 'git-bundle-web-server' exists: %w", err) } else if !programExists { @@ -97,10 +93,7 @@ func (w *webServerCmd) startServer(ctx context.Context, args []string) error { parser.Parse(ctx, args) validate(ctx) - d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) - if err != nil { - return w.logger.Error(ctx, err) - } + d := utils.GetDependency[daemon.DaemonProvider](ctx, w.container) config, err := w.getDaemonConfig(ctx) if err != nil { @@ -153,10 +146,7 @@ func (w *webServerCmd) stopServer(ctx context.Context, args []string) error { remove := parser.Bool("remove", false, "Remove the web server daemon configuration from the system after stopping") parser.Parse(ctx, args) - d, err := daemon.NewDaemonProvider(w.logger, w.user, w.cmdExec, w.fileSystem) - if err != nil { - return w.logger.Error(ctx, err) - } + d := utils.GetDependency[daemon.DaemonProvider](ctx, w.container) config, err := w.getDaemonConfig(ctx) if err != nil { diff --git a/cmd/git-bundle-web-server/bundle-server.go b/cmd/git-bundle-web-server/bundle-server.go index 78cb50d..8f6a158 100644 --- a/cmd/git-bundle-web-server/bundle-server.go +++ b/cmd/git-bundle-web-server/bundle-server.go @@ -66,13 +66,10 @@ func (b *bundleWebServer) parseRoute(ctx context.Context, path string) (string, func (b *bundleWebServer) serve(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - user, err := common.NewUserProvider().CurrentUser() - if err != nil { - return - } - fs := common.NewFileSystem() - path := r.URL.Path + ctx, exitRegion := b.logger.Region(ctx, "http", "serve") + defer exitRegion() + path := r.URL.Path owner, repo, file, err := b.parseRoute(ctx, path) if err != nil { w.WriteHeader(http.StatusNotFound) @@ -82,7 +79,11 @@ func (b *bundleWebServer) serve(w http.ResponseWriter, r *http.Request) { route := owner + "/" + repo - repos, err := core.GetRepositories(user, fs) + userProvider := common.NewUserProvider() + fileSystem := common.NewFileSystem() + repoProvider := core.NewRepositoryProvider(b.logger, userProvider, fileSystem) + + repos, err := repoProvider.GetRepositories(ctx) if err != nil { w.WriteHeader(http.StatusInternalServerError) fmt.Printf("Failed to load routes\n") diff --git a/cmd/utils/container-helpers.go b/cmd/utils/container-helpers.go new file mode 100644 index 0000000..f84febb --- /dev/null +++ b/cmd/utils/container-helpers.go @@ -0,0 +1,48 @@ +package utils + +import ( + "context" + + "github.com/github/git-bundle-server/internal/bundles" + "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/core" + "github.com/github/git-bundle-server/internal/daemon" + "github.com/github/git-bundle-server/internal/log" +) + +func BuildGitBundleServerContainer(logger log.TraceLogger) *DependencyContainer { + container := NewDependencyContainer() + registerDependency(container, func(ctx context.Context) common.UserProvider { + return common.NewUserProvider() + }) + registerDependency(container, func(ctx context.Context) common.CommandExecutor { + return common.NewCommandExecutor() + }) + registerDependency(container, func(ctx context.Context) common.FileSystem { + return common.NewFileSystem() + }) + registerDependency(container, func(ctx context.Context) core.RepositoryProvider { + return core.NewRepositoryProvider( + logger, + GetDependency[common.UserProvider](ctx, container), + GetDependency[common.FileSystem](ctx, container), + ) + }) + registerDependency(container, func(ctx context.Context) bundles.BundleProvider { + return bundles.NewBundleProvider(logger) + }) + registerDependency(container, func(ctx context.Context) daemon.DaemonProvider { + t, err := daemon.NewDaemonProvider( + logger, + GetDependency[common.UserProvider](ctx, container), + GetDependency[common.CommandExecutor](ctx, container), + GetDependency[common.FileSystem](ctx, container), + ) + if err != nil { + logger.Fatal(ctx, err) + } + return t + }) + + return container +} diff --git a/cmd/utils/container-helpers_test.go b/cmd/utils/container-helpers_test.go new file mode 100644 index 0000000..efd0226 --- /dev/null +++ b/cmd/utils/container-helpers_test.go @@ -0,0 +1,120 @@ +package utils_test + +import ( + "bytes" + "context" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "reflect" + "strings" + "testing" + + "github.com/github/git-bundle-server/cmd/utils" + . "github.com/github/git-bundle-server/internal/testhelpers" + typeutils "github.com/github/git-bundle-server/internal/utils" + "github.com/stretchr/testify/assert" +) + +func findAllGetDependencyTypesInDir(relativePathToDir string) (*token.FileSet, []ast.Expr) { + fset := token.NewFileSet() // positions are relative to fset + pkgs, err := parser.ParseDir(fset, relativePathToDir, nil, 0) + if err != nil { + panic("could not read directory") + } + + typeNodes := []ast.Expr{} + for _, pkg := range pkgs { + ast.Inspect(pkg, func(n ast.Node) bool { + switch x := n.(type) { + // Might be an invocation of GetDependency + case *ast.IndexExpr: + fnSelector, ok := x.X.(*ast.SelectorExpr) + if !ok { + return true + } + + if fnSelector.Sel.Name == "GetDependency" { + // Now, get the identifier (or selector) for the type + typeNodes = append(typeNodes, x.Index) + } + + return true + } + + // Keep recursing! + return true + }) + } + + return fset, typeNodes +} + +// Test that all GetDependency invocations in the 'git-bundle-server' 'main' +// package are in the container built by 'BuildGitBundleServerContainer'. +// +// This test is somewhat fragile, and it isn't comprehensive. Its main utility +// is to cover easy-to-miss runtime issues that could arise from the dependency +// provider. +// +// Scenarios that can cause issues: +// - 'utils' (not to be confused with 'internal/utils') is imported with a +// dot-import in the tested file (don't do this!). +// - types requested by 'GetDependency()' belong to an explicitly aliased or +// dot-imported package. +// - invocations of the dependency container in files outside the tested +// packages (also don't do this!). +func TestDependencyContainer(t *testing.T) { + logger := &MockTraceLogger{} + ctx := context.Background() + + t.Run("Container is created successfully", func(t *testing.T) { + assert.NotPanics(t, func() { utils.BuildGitBundleServerContainer(logger) }) + }) + + t.Run("Verify container is internally consistent", func(t *testing.T) { + container := utils.BuildGitBundleServerContainer(logger) + assert.NotPanics(t, func() { container.InvokeAll(ctx) }) + }) + + t.Run("Verify all external invocations are registered", func(t *testing.T) { + container := utils.BuildGitBundleServerContainer(logger) + registeredTypes := typeutils.Map(container.ListRegisteredTypes(), + func(t reflect.Type) string { + return t.String() + }, + ) + + fset, typeNodes := findAllGetDependencyTypesInDir("../git-bundle-server") + + // We expect at least one registered dependency, otherwise get rid of + // this test. + assert.NotEmpty(t, typeNodes) + + // Ensure each node is found in the container + for _, node := range typeNodes { + var name string + if ident, ok := node.(*ast.Ident); ok { + // No package identified with the type - + pkgPath := reflect.TypeOf(*container).PkgPath() + pkgComponents := strings.Split(pkgPath, "/") + name = fmt.Sprintf("%s.%s", pkgComponents[len(pkgComponents)-1], ident.Name) + } else { + var nameBuf bytes.Buffer + err := printer.Fprint(&nameBuf, fset, node) + if err != nil { + assert.Fail(t, err.Error()) + } + name = nameBuf.String() + } + + // check if name is in registered types list + callLocation := fset.Position(node.Pos()) + assert.Contains(t, registeredTypes, name, + "Type %s was not registered; see call to 'GetDependency()' in file '%s', line %d", + name, callLocation.Filename, callLocation.Line) + } + }) +} diff --git a/cmd/utils/container.go b/cmd/utils/container.go new file mode 100644 index 0000000..18089a1 --- /dev/null +++ b/cmd/utils/container.go @@ -0,0 +1,72 @@ +package utils + +import ( + "context" + "fmt" + "reflect" + "sync" +) + +type Initializer[T interface{}] func(ctx context.Context) T + +// Contains lazily generated singletons of stateless struct used throughout +// the application. Thread-safe. +type DependencyContainer struct { + singletonInitializers *sync.Map +} + +func NewDependencyContainer() *DependencyContainer { + return &DependencyContainer{ + singletonInitializers: &sync.Map{}, + } +} + +func (d *DependencyContainer) ListRegisteredTypes() []reflect.Type { + typeList := []reflect.Type{} + d.singletonInitializers.Range(func(key any, value any) bool { + asType, ok := key.(reflect.Type) + if !ok { + panic("key to singletonInitializers was not 'reflect.Type'") + } + typeList = append(typeList, asType) + return true + }) + return typeList +} + +func (d *DependencyContainer) InvokeAll(ctx context.Context) { + d.singletonInitializers.Range(func(key any, value any) bool { + reflectValue := reflect.ValueOf(value) + reflectValue.Call([]reflect.Value{reflect.ValueOf(ctx)}) + return true + }) +} + +// These *should* be generic methods of DependencyContainer, but generic methods +// aren't supported in Go... (ノಠ益ಠ)ノ彡┻━┻ + +func wrapFuncForSingleton[T interface{}](initializer Initializer[T]) Initializer[T] { + var t T + var once sync.Once + return func(ctx context.Context) T { + once.Do(func() { + t = initializer(ctx) + }) + return t + } +} + +func registerDependency[T interface{}](d *DependencyContainer, initializer Initializer[T]) { + tType := reflect.TypeOf((*T)(nil)).Elem() + d.singletonInitializers.Store(tType, wrapFuncForSingleton(initializer)) +} + +func GetDependency[T interface{}](ctx context.Context, d *DependencyContainer) T { + tType := reflect.TypeOf((*T)(nil)).Elem() + + if initializer, ok := d.singletonInitializers.Load(tType); !ok { + panic(fmt.Sprintf("no initializer registered for type '%s'", tType)) + } else { + return initializer.(Initializer[T])(ctx) + } +} diff --git a/internal/bundles/bundles.go b/internal/bundles/bundles.go index 06dfe10..a87758c 100644 --- a/internal/bundles/bundles.go +++ b/internal/bundles/bundles.go @@ -2,6 +2,7 @@ package bundles import ( "bufio" + "context" "encoding/json" "fmt" "os" @@ -12,6 +13,7 @@ import ( "github.com/github/git-bundle-server/internal/core" "github.com/github/git-bundle-server/internal/git" + "github.com/github/git-bundle-server/internal/log" ) type BundleHeader struct { @@ -37,11 +39,42 @@ type BundleList struct { Bundles map[int64]Bundle } -func addBundleToList(bundle Bundle, list *BundleList) { +func (list *BundleList) addBundle(bundle Bundle) { list.Bundles[bundle.CreationToken] = bundle } -func CreateInitialBundle(repo *core.Repository) Bundle { +func (list *BundleList) sortedCreationTokens() []int64 { + keys := make([]int64, 0, len(list.Bundles)) + for timestamp := range list.Bundles { + keys = append(keys, timestamp) + } + + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + + return keys +} + +type BundleProvider interface { + CreateInitialBundle(ctx context.Context, repo *core.Repository) Bundle + CreateIncrementalBundle(ctx context.Context, repo *core.Repository, list *BundleList) (*Bundle, error) + + CreateSingletonList(ctx context.Context, bundle Bundle) *BundleList + WriteBundleList(ctx context.Context, list *BundleList, repo *core.Repository) error + GetBundleList(ctx context.Context, repo *core.Repository) (*BundleList, error) + CollapseList(ctx context.Context, repo *core.Repository, list *BundleList) error +} + +type bundleProvider struct { + logger log.TraceLogger +} + +func NewBundleProvider(logger log.TraceLogger) BundleProvider { + return &bundleProvider{ + logger: logger, + } +} + +func (b *bundleProvider) CreateInitialBundle(ctx context.Context, repo *core.Repository) Bundle { timestamp := time.Now().UTC().Unix() bundleName := "bundle-" + fmt.Sprint(timestamp) + ".bundle" bundleFile := repo.WebDir + "/" + bundleName @@ -54,10 +87,10 @@ func CreateInitialBundle(repo *core.Repository) Bundle { return bundle } -func CreateDistinctBundle(repo *core.Repository, list *BundleList) Bundle { +func (b *bundleProvider) createDistinctBundle(repo *core.Repository, list *BundleList) Bundle { timestamp := time.Now().UTC().Unix() - keys := GetSortedCreationTokens(list) + keys := list.sortedCreationTokens() maxTimestamp := keys[len(keys)-1] if timestamp <= maxTimestamp { @@ -75,16 +108,20 @@ func CreateDistinctBundle(repo *core.Repository, list *BundleList) Bundle { return bundle } -func CreateSingletonList(bundle Bundle) *BundleList { +func (b *bundleProvider) CreateSingletonList(ctx context.Context, bundle Bundle) *BundleList { list := BundleList{1, "all", make(map[int64]Bundle)} - addBundleToList(bundle, &list) + list.addBundle(bundle) return &list } // Given a BundleList -func WriteBundleList(list *BundleList, repo *core.Repository) error { +func (b *bundleProvider) WriteBundleList(ctx context.Context, list *BundleList, repo *core.Repository) error { + //lint:ignore SA4006 always override the ctx with the result from 'Region()' + ctx, exitRegion := b.logger.Region(ctx, "bundles", "write_bundle_list") + defer exitRegion() + listFile := repo.WebDir + "/bundle-list" jsonFile := repo.RepoDir + "/bundle-list.json" @@ -100,7 +137,7 @@ func WriteBundleList(list *BundleList, repo *core.Repository) error { out, "[bundle]\n\tversion = %d\n\tmode = %s\n\n", list.Version, list.Mode) - keys := GetSortedCreationTokens(list) + keys := list.sortedCreationTokens() for _, token := range keys { bundle := list.Bundles[token] @@ -145,7 +182,11 @@ func WriteBundleList(list *BundleList, repo *core.Repository) error { return os.Rename(listFile+".lock", listFile) } -func GetBundleList(repo *core.Repository) (*BundleList, error) { +func (b *bundleProvider) GetBundleList(ctx context.Context, repo *core.Repository) (*BundleList, error) { + //lint:ignore SA4006 always override the ctx with the result from 'Region()' + ctx, exitRegion := b.logger.Region(ctx, "bundles", "get_bundle_list") + defer exitRegion() + jsonFile := repo.RepoDir + "/bundle-list.json" reader, err := os.Open(jsonFile) @@ -162,7 +203,7 @@ func GetBundleList(repo *core.Repository) (*BundleList, error) { return &list, nil } -func GetBundleHeader(bundle Bundle) (*BundleHeader, error) { +func (b *bundleProvider) getBundleHeader(bundle Bundle) (*BundleHeader, error) { file, err := os.Open(bundle.Filename) if err != nil { return nil, fmt.Errorf("failed to open bundle file: %w", err) @@ -232,11 +273,11 @@ func GetBundleHeader(bundle Bundle) (*BundleHeader, error) { return &header, nil } -func GetAllPrereqsForIncrementalBundle(list *BundleList) ([]string, error) { +func (b *bundleProvider) getAllPrereqsForIncrementalBundle(list *BundleList) ([]string, error) { prereqs := []string{} for _, bundle := range list.Bundles { - header, err := GetBundleHeader(bundle) + header, err := b.getBundleHeader(bundle) if err != nil { return nil, fmt.Errorf("failed to parse bundle file %s: %w", bundle.Filename, err) } @@ -249,10 +290,13 @@ func GetAllPrereqsForIncrementalBundle(list *BundleList) ([]string, error) { return prereqs, nil } -func CreateIncrementalBundle(repo *core.Repository, list *BundleList) (*Bundle, error) { - bundle := CreateDistinctBundle(repo, list) +func (b *bundleProvider) CreateIncrementalBundle(ctx context.Context, repo *core.Repository, list *BundleList) (*Bundle, error) { + ctx, exitRegion := b.logger.Region(ctx, "bundles", "create_incremental_bundle") + defer exitRegion() - lines, err := GetAllPrereqsForIncrementalBundle(list) + bundle := b.createDistinctBundle(repo, list) + + lines, err := b.getAllPrereqsForIncrementalBundle(list) if err != nil { return nil, err } @@ -269,14 +313,17 @@ func CreateIncrementalBundle(repo *core.Repository, list *BundleList) (*Bundle, return &bundle, nil } -func CollapseList(repo *core.Repository, list *BundleList) error { +func (b *bundleProvider) CollapseList(ctx context.Context, repo *core.Repository, list *BundleList) error { + ctx, exitRegion := b.logger.Region(ctx, "bundles", "collapse_list") + defer exitRegion() + maxBundles := 5 if len(list.Bundles) <= maxBundles { return nil } - keys := GetSortedCreationTokens(list) + keys := list.sortedCreationTokens() refs := make(map[string]string) @@ -289,7 +336,7 @@ func CollapseList(repo *core.Repository, list *BundleList) error { maxTimestamp = bundle.CreationToken } - header, err := GetBundleHeader(bundle) + header, err := b.getBundleHeader(bundle) if err != nil { return fmt.Errorf("failed to parse bundle file %s: %w", bundle.Filename, err) } @@ -327,14 +374,3 @@ func CollapseList(repo *core.Repository, list *BundleList) error { list.Bundles[maxTimestamp] = bundle return nil } - -func GetSortedCreationTokens(list *BundleList) []int64 { - keys := make([]int64, 0, len(list.Bundles)) - for timestamp := range list.Bundles { - keys = append(keys, timestamp) - } - - sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) - - return keys -} diff --git a/internal/common/command.go b/internal/common/command.go index 8317505..1b8652c 100644 --- a/internal/common/command.go +++ b/internal/common/command.go @@ -11,7 +11,7 @@ type CommandExecutor interface { type commandExecutor struct{} -func NewCommandExecutor() *commandExecutor { +func NewCommandExecutor() CommandExecutor { return &commandExecutor{} } diff --git a/internal/common/filesystem.go b/internal/common/filesystem.go index 94d62d5..58970f7 100644 --- a/internal/common/filesystem.go +++ b/internal/common/filesystem.go @@ -18,7 +18,7 @@ type FileSystem interface { type fileSystem struct{} -func NewFileSystem() *fileSystem { +func NewFileSystem() FileSystem { return &fileSystem{} } diff --git a/internal/common/user.go b/internal/common/user.go index 257e911..dd582cd 100644 --- a/internal/common/user.go +++ b/internal/common/user.go @@ -10,7 +10,7 @@ type UserProvider interface { type userProvider struct{} -func NewUserProvider() *userProvider { +func NewUserProvider() UserProvider { return &userProvider{} } diff --git a/internal/core/repo.go b/internal/core/repo.go index aa93b42..5834a28 100644 --- a/internal/core/repo.go +++ b/internal/core/repo.go @@ -1,12 +1,12 @@ package core import ( + "context" "fmt" - "log" "os" - "os/user" "github.com/github/git-bundle-server/internal/common" + "github.com/github/git-bundle-server/internal/log" ) type Repository struct { @@ -15,13 +15,39 @@ type Repository struct { WebDir string } -func CreateRepository(route string) (*Repository, error) { - user, err := common.NewUserProvider().CurrentUser() +type RepositoryProvider interface { + CreateRepository(ctx context.Context, route string) (*Repository, error) + GetRepositories(ctx context.Context) (map[string]Repository, error) + RemoveRoute(ctx context.Context, route string) error +} + +type repoProvider struct { + logger log.TraceLogger + user common.UserProvider + fileSystem common.FileSystem +} + +func NewRepositoryProvider(logger log.TraceLogger, + u common.UserProvider, + fs common.FileSystem, +) RepositoryProvider { + return &repoProvider{ + logger: logger, + user: u, + fileSystem: fs, + } +} + +func (r *repoProvider) CreateRepository(ctx context.Context, route string) (*Repository, error) { + ctx, exitRegion := r.logger.Region(ctx, "repo", "create_repo") + defer exitRegion() + + user, err := r.user.CurrentUser() if err != nil { return nil, err } - fs := common.NewFileSystem() - repos, err := GetRepositories(user, fs) + + repos, err := r.GetRepositories(ctx) if err != nil { return nil, fmt.Errorf("failed to parse routes file: %w", err) } @@ -36,7 +62,7 @@ func CreateRepository(route string) (*Repository, error) { mkdirErr := os.MkdirAll(web, os.ModePerm) if mkdirErr != nil { - log.Fatal("failed to create web directory: ", mkdirErr) + return nil, fmt.Errorf("failed to create web directory: %w", mkdirErr) } repo = Repository{ @@ -47,7 +73,7 @@ func CreateRepository(route string) (*Repository, error) { repos[route] = repo - err = WriteRouteFile(repos) + err = r.writeRouteFile(repos) if err != nil { return nil, fmt.Errorf("warning: failed to write route file") } @@ -55,13 +81,11 @@ func CreateRepository(route string) (*Repository, error) { return &repo, nil } -func RemoveRoute(route string) error { - user, err := common.NewUserProvider().CurrentUser() - if err != nil { - return err - } - fs := common.NewFileSystem() - repos, err := GetRepositories(user, fs) +func (r *repoProvider) RemoveRoute(ctx context.Context, route string) error { + ctx, exitRegion := r.logger.Region(ctx, "repo", "remove_route") + defer exitRegion() + + repos, err := r.GetRepositories(ctx) if err != nil { return fmt.Errorf("failed to parse routes file: %w", err) } @@ -73,11 +97,11 @@ func RemoveRoute(route string) error { delete(repos, route) - return WriteRouteFile(repos) + return r.writeRouteFile(repos) } -func WriteRouteFile(repos map[string]Repository) error { - user, err := common.NewUserProvider().CurrentUser() +func (r *repoProvider) writeRouteFile(repos map[string]Repository) error { + user, err := r.user.CurrentUser() if err != nil { return err } @@ -93,13 +117,21 @@ func WriteRouteFile(repos map[string]Repository) error { return os.WriteFile(routefile, []byte(contents), 0o600) } -func GetRepositories(user *user.User, fs common.FileSystem) (map[string]Repository, error) { +func (r *repoProvider) GetRepositories(ctx context.Context) (map[string]Repository, error) { + ctx, exitRegion := r.logger.Region(ctx, "repo", "get_repos") //lint:ignore SA4006 keep ctx up-to-date + defer exitRegion() + + user, err := r.user.CurrentUser() + if err != nil { + return nil, err + } + repos := make(map[string]Repository) dir := bundleroot(user) routefile := dir + "/routes" - lines, err := fs.ReadFileLines(routefile) + lines, err := r.fileSystem.ReadFileLines(routefile) if err != nil { return nil, err } diff --git a/internal/core/repo_test.go b/internal/core/repo_test.go index da2d94c..326db3b 100644 --- a/internal/core/repo_test.go +++ b/internal/core/repo_test.go @@ -1,6 +1,7 @@ package core_test import ( + "context" "errors" "os/user" "testing" @@ -83,12 +84,16 @@ var getRepositoriesTests = []struct { } func TestRepos_GetRepositories(t *testing.T) { + testLogger := &MockTraceLogger{} testFileSystem := &MockFileSystem{} testUser := &user.User{ Uid: "123", Username: "testuser", HomeDir: "/my/test/dir", } + testUserProvider := &MockUserProvider{} + testUserProvider.On("CurrentUser").Return(testUser, nil) + repoProvider := core.NewRepositoryProvider(testLogger, testUserProvider, testFileSystem) for _, tt := range getRepositoriesTests { t.Run(tt.title, func(t *testing.T) { @@ -97,7 +102,7 @@ func TestRepos_GetRepositories(t *testing.T) { mock.AnythingOfType("string"), ).Return(tt.readFileLines.First, tt.readFileLines.Second).Once() - actual, err := core.GetRepositories(testUser, testFileSystem) + actual, err := repoProvider.GetRepositories(context.Background()) if tt.expectedErr { assert.NotNil(t, err, "Expected error") diff --git a/internal/log/logger.go b/internal/log/logger.go index 41ff3c2..c8a419f 100644 --- a/internal/log/logger.go +++ b/internal/log/logger.go @@ -12,6 +12,7 @@ import ( type loggedError error type TraceLogger interface { + Region(ctx context.Context, category string, label string) (context.Context, func()) LogCommand(ctx context.Context, commandName string) context.Context Error(ctx context.Context, err error) error Errorf(ctx context.Context, format string, a ...any) error diff --git a/internal/log/trace2.go b/internal/log/trace2.go index 789dd55..14600bd 100644 --- a/internal/log/trace2.go +++ b/internal/log/trace2.go @@ -32,8 +32,14 @@ type ctxKey int const ( sidId ctxKey = iota + parentRegionId ) +type trace2Region struct { + level int + tStart time.Time +} + type Trace2 struct { logger *zap.Logger } @@ -105,24 +111,51 @@ func (l fieldList) withTime() fieldList { return append(l, zap.Float64("t_abs", time.Since(globalStart).Seconds())) } +func (l fieldList) withNesting(r trace2Region, includeTRel bool) fieldList { + l = append(l, zap.Int("nesting", r.level)) + if includeTRel { + l = append(l, zap.Float64("t_rel", time.Since(r.tStart).Seconds())) + } + return l +} + func (l fieldList) with(f ...zap.Field) fieldList { return append(l, f...) } +func getContextValue[T any]( + ctx context.Context, + key ctxKey, +) (bool, T) { + var value T + haveValue := false + valueAny := ctx.Value(key) + if valueAny != nil { + value, haveValue = valueAny.(T) + } + return haveValue, value +} + +func getOrSetContextValue[T any]( + ctx context.Context, + key ctxKey, + newValueFunc func() T, +) (context.Context, T) { + var value T + haveValue, value := getContextValue[T](ctx, key) + if !haveValue { + value = newValueFunc() + ctx = context.WithValue(ctx, key, value) + } + + return ctx, value +} + func (t *Trace2) sharedFields(ctx context.Context) (context.Context, fieldList) { fields := fieldList{} // Get the session ID - var sid uuid.UUID - haveSid := false - sidAny := ctx.Value(sidId) - if sidAny != nil { - sid, haveSid = sidAny.(uuid.UUID) - } - if !haveSid { - sid = uuid.New() - ctx = context.WithValue(ctx, sidId, sid) - } + ctx, sid := getOrSetContextValue(ctx, sidId, uuid.New) fields = append(fields, zap.String("sid", sid.String())) // Hardcode the thread to "main" because Go doesn't like to share its @@ -132,7 +165,7 @@ func (t *Trace2) sharedFields(ctx context.Context) (context.Context, fieldList) // Get the caller of the function in trace2.go // Skip up two levels: // 0: this function - // 1: the caller of this function (StartTrace, LogEvent, etc.) + // 1: the caller of this function (logStart, Error, etc.) // 2: the function calling this trace2 library _, fileName, lineNum, ok := runtime.Caller(2) if ok { @@ -166,6 +199,33 @@ func (t *Trace2) logExit(ctx context.Context, exitCode int) { t.logger.Sync() } +func (t *Trace2) Region(ctx context.Context, category string, label string) (context.Context, func()) { + ctx, sharedFields := t.sharedFields(ctx) + + // Get the nesting level & increment + hasParentRegion, nesting := getContextValue[trace2Region](ctx, parentRegionId) + if !hasParentRegion { + nesting = trace2Region{ + level: 0, + tStart: time.Now(), + } + } else { + nesting.level++ + nesting.tStart = time.Now() + } + ctx = context.WithValue(ctx, parentRegionId, nesting) + + regionFields := fieldList{ + zap.String("category", category), + zap.String("label", label), + } + + t.logger.Debug("region_enter", sharedFields.withNesting(nesting, false).with(regionFields...)...) + return ctx, func() { + t.logger.Debug("region_leave", sharedFields.withNesting(nesting, true).with(regionFields...)...) + } +} + func (t *Trace2) LogCommand(ctx context.Context, commandName string) context.Context { ctx, sharedFields := t.sharedFields(ctx)