Skip to content

Commit 49b7e1d

Browse files
authored
Merge pull request #10 from github/vdye/https
Add option for serving over HTTPS
2 parents f64602c + 5377d04 commit 49b7e1d

File tree

10 files changed

+598
-140
lines changed

10 files changed

+598
-140
lines changed

cmd/git-bundle-server/web-server.go

+39
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package main
22

33
import (
44
"errors"
5+
"flag"
56
"fmt"
67
"os"
78
"os/exec"
89
"path/filepath"
910

11+
"github.com/github/git-bundle-server/cmd/utils"
1012
"github.com/github/git-bundle-server/internal/argparse"
1113
"github.com/github/git-bundle-server/internal/common"
1214
"github.com/github/git-bundle-server/internal/daemon"
@@ -77,9 +79,19 @@ func (w *webServer) getDaemonConfig() (*daemon.DaemonConfig, error) {
7779
func (w *webServer) startServer(args []string) error {
7880
// Parse subcommand arguments
7981
parser := argparse.NewArgParser("git-bundle-server web-server start [-f|--force]")
82+
83+
// Args for 'git-bundle-server web-server start'
8084
force := parser.Bool("force", false, "Whether to force reconfiguration of the web server daemon")
8185
parser.BoolVar(force, "f", false, "Alias of --force")
86+
87+
// Arguments passed through to 'git-bundle-web-server'
88+
webServerFlags, validate := utils.WebServerFlags(parser)
89+
webServerFlags.VisitAll(func(f *flag.Flag) {
90+
parser.Var(f.Value, f.Name, fmt.Sprintf("[Web server] %s", f.Usage))
91+
})
92+
8293
parser.Parse(args)
94+
validate()
8395

8496
d, err := daemon.NewDaemonProvider(w.user, w.cmdExec, w.fileSystem)
8597
if err != nil {
@@ -91,6 +103,33 @@ func (w *webServer) startServer(args []string) error {
91103
return err
92104
}
93105

106+
// Configure flags
107+
loopErr := error(nil)
108+
parser.Visit(func(f *flag.Flag) {
109+
if webServerFlags.Lookup(f.Name) != nil {
110+
value := f.Value.String()
111+
if f.Name == "cert" || f.Name == "key" {
112+
// Need the absolute value of the path
113+
value, err = filepath.Abs(value)
114+
if err != nil {
115+
if loopErr == nil {
116+
// NEEDSWORK: Only report the first error because Go
117+
// doesn't like it when you manually chain errors :(
118+
// Luckily, this is slated to change in v1.20, per
119+
// https://tip.golang.org/doc/go1.20#errors
120+
loopErr = fmt.Errorf("could not get absolute path of '%s': %w", f.Name, err)
121+
}
122+
return
123+
}
124+
}
125+
config.Arguments = append(config.Arguments, fmt.Sprintf("--%s", f.Name), value)
126+
}
127+
})
128+
if loopErr != nil {
129+
// Error happened in 'Visit'
130+
return loopErr
131+
}
132+
94133
err = d.Create(config, *force)
95134
if err != nil {
96135
return err

cmd/git-bundle-web-server/main.go

+46-33
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"flag"
56
"fmt"
67
"log"
78
"net/http"
@@ -11,34 +12,25 @@ import (
1112
"sync"
1213
"syscall"
1314

15+
"github.com/github/git-bundle-server/cmd/utils"
16+
"github.com/github/git-bundle-server/internal/argparse"
1417
"github.com/github/git-bundle-server/internal/core"
1518
)
1619

1720
func parseRoute(path string) (string, string, string, error) {
18-
if len(path) == 0 {
21+
elements := strings.FieldsFunc(path, func(char rune) bool { return char == '/' })
22+
switch len(elements) {
23+
case 0:
1924
return "", "", "", fmt.Errorf("empty route")
20-
}
21-
22-
if path[0] == '/' {
23-
path = path[1:]
24-
}
25-
26-
slash1 := strings.Index(path, "/")
27-
if slash1 < 0 {
25+
case 1:
2826
return "", "", "", fmt.Errorf("route has owner, but no repo")
29-
}
30-
slash2 := strings.Index(path[slash1+1:], "/")
31-
if slash2 < 0 {
32-
// No trailing slash.
33-
return path[:slash1], path[slash1+1:], "", nil
34-
}
35-
slash2 += slash1 + 1
36-
slash3 := strings.Index(path[slash2+1:], "/")
37-
if slash3 >= 0 {
27+
case 2:
28+
return elements[0], elements[1], "", nil
29+
case 3:
30+
return elements[0], elements[1], elements[2], nil
31+
default:
3832
return "", "", "", fmt.Errorf("path has depth exceeding three")
3933
}
40-
41-
return path[:slash1], path[slash1+1 : slash2], path[slash2+1:], nil
4234
}
4335

4436
func serve(w http.ResponseWriter, r *http.Request) {
@@ -83,36 +75,57 @@ func serve(w http.ResponseWriter, r *http.Request) {
8375
w.Write(data)
8476
}
8577

86-
func createAndStartServer(address string, serverWaitGroup *sync.WaitGroup) *http.Server {
87-
// Create the HTTP server
88-
server := &http.Server{Addr: address}
89-
90-
// API routes
91-
http.HandleFunc("/", serve)
92-
78+
func startServer(server *http.Server,
79+
cert string, key string,
80+
serverWaitGroup *sync.WaitGroup,
81+
) {
9382
// Add to wait group
9483
serverWaitGroup.Add(1)
9584

9685
go func() {
9786
defer serverWaitGroup.Done()
9887

9988
// Return error unless it indicates graceful shutdown
100-
err := server.ListenAndServe()
101-
if err != http.ErrServerClosed {
89+
var err error
90+
if cert != "" {
91+
err = server.ListenAndServeTLS(cert, key)
92+
} else {
93+
err = server.ListenAndServe()
94+
}
95+
96+
if err != nil && err != http.ErrServerClosed {
10297
log.Fatal(err)
10398
}
10499
}()
105100

106-
fmt.Println("Server is running at address " + address)
107-
return server
101+
fmt.Println("Server is running at address " + server.Addr)
108102
}
109103

110104
func main() {
105+
parser := argparse.NewArgParser("git-bundle-web-server [--port <port>] [--cert <filename> --key <filename>]")
106+
flags, validate := utils.WebServerFlags(parser)
107+
flags.VisitAll(func(f *flag.Flag) {
108+
parser.Var(f.Value, f.Name, f.Usage)
109+
})
110+
parser.Parse(os.Args[1:])
111+
validate()
112+
113+
// Get the flag values
114+
port := utils.GetFlagValue[string](parser, "port")
115+
cert := utils.GetFlagValue[string](parser, "cert")
116+
key := utils.GetFlagValue[string](parser, "key")
117+
118+
// Configure the server
119+
mux := http.NewServeMux()
120+
mux.HandleFunc("/", serve)
121+
server := &http.Server{
122+
Handler: mux,
123+
Addr: ":" + port,
124+
}
111125
serverWaitGroup := &sync.WaitGroup{}
112126

113127
// Start the server asynchronously
114-
port := ":8080"
115-
server := createAndStartServer(port, serverWaitGroup)
128+
startServer(server, cert, key, serverWaitGroup)
116129

117130
// Intercept interrupt signals
118131
c := make(chan os.Signal, 1)

cmd/utils/common-args.go

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package utils
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"strconv"
7+
)
8+
9+
// Helpers
10+
11+
// Forward declaration (kinda) of argParser.
12+
// 'argparse.argParser' is private, but we want to be able to pass instances
13+
// of it to functions, so we need to define an interface that includes the
14+
// functions we want to call from the parser.
15+
type argParser interface {
16+
Lookup(name string) *flag.Flag
17+
Usage(errFmt string, args ...any)
18+
}
19+
20+
func GetFlagValue[T any](parser argParser, name string) T {
21+
flagVal := parser.Lookup(name)
22+
if flagVal == nil {
23+
panic(fmt.Sprintf("flag '--%s' is undefined", name))
24+
}
25+
26+
flagGetter, ok := flagVal.Value.(flag.Getter)
27+
if !ok {
28+
panic(fmt.Sprintf("flag '--%s' is invalid (does not implement flag.Getter)", name))
29+
}
30+
31+
value, ok := flagGetter.Get().(T)
32+
if !ok {
33+
panic(fmt.Sprintf("flag '--%s' is invalid (cannot cast to appropriate type)", name))
34+
}
35+
36+
return value
37+
}
38+
39+
// Sets of flags shared between multiple commands/programs
40+
41+
func WebServerFlags(parser argParser) (*flag.FlagSet, func()) {
42+
f := flag.NewFlagSet("", flag.ContinueOnError)
43+
port := f.String("port", "8080", "The port on which the server should be hosted")
44+
cert := f.String("cert", "", "The path to the X.509 SSL certificate file to use in securely hosting the server")
45+
key := f.String("key", "", "The path to the certificate's private key")
46+
47+
// Function to call for additional arg validation (may exit with 'Usage()')
48+
validationFunc := func() {
49+
p, err := strconv.Atoi(*port)
50+
if err != nil || p < 0 || p > 65535 {
51+
parser.Usage("Invalid port '%s'.", *port)
52+
}
53+
if (*cert == "") != (*key == "") {
54+
parser.Usage("Both '--cert' and '--key' are needed to specify SSL configuration.")
55+
}
56+
}
57+
58+
return f, validationFunc
59+
}

internal/daemon/daemon.go

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ type DaemonConfig struct {
1111
Label string
1212
Description string
1313
Program string
14+
Arguments []string
1415
}
1516

1617
type DaemonProvider interface {

internal/daemon/launchd.go

+73-15
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,54 @@ package daemon
22

33
import (
44
"bytes"
5+
"encoding/xml"
56
"fmt"
67
"path/filepath"
7-
"text/template"
88

99
"github.com/github/git-bundle-server/internal/common"
10+
"github.com/github/git-bundle-server/internal/utils"
1011
)
1112

12-
const launchTemplate string = `<?xml version="1.0" encoding="UTF-8"?>
13-
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
14-
<plist version="1.0">
15-
<dict>
16-
<key>Label</key><string>{{.Label}}</string>
17-
<key>Program</key><string>{{.Program}}</string>
18-
<key>StandardOutPath</key><string>{{.StdOut}}</string>
19-
<key>StandardErrorPath</key><string>{{.StdErr}}</string>
20-
</dict>
21-
</plist>
22-
`
13+
type xmlItem struct {
14+
XMLName xml.Name
15+
Value string `xml:",chardata"`
16+
}
17+
18+
type xmlArray struct {
19+
XMLName xml.Name
20+
Elements []interface{} `xml:",any"`
21+
}
22+
23+
type plist struct {
24+
XMLName xml.Name `xml:"plist"`
25+
Version string `xml:"version,attr"`
26+
Config xmlArray `xml:"dict"`
27+
}
28+
29+
const plistHeader string = `<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">`
30+
31+
func xmlName(name string) xml.Name {
32+
return xml.Name{Local: name}
33+
}
34+
35+
func (p *plist) addKeyValue(key string, value any) {
36+
p.Config.Elements = append(p.Config.Elements, xmlItem{XMLName: xmlName("key"), Value: key})
37+
switch value := value.(type) {
38+
case string:
39+
p.Config.Elements = append(p.Config.Elements, xmlItem{XMLName: xmlName("string"), Value: value})
40+
case []string:
41+
p.Config.Elements = append(p.Config.Elements,
42+
xmlArray{
43+
XMLName: xmlName("array"),
44+
Elements: utils.Map(value, func(e string) interface{} {
45+
return xmlItem{XMLName: xmlName("string"), Value: e}
46+
}),
47+
},
48+
)
49+
default:
50+
panic("Invalid value type in 'addKeyValue'")
51+
}
52+
}
2353

2454
const domainFormat string = "gui/%s"
2555

@@ -31,6 +61,31 @@ type launchdConfig struct {
3161
StdErr string
3262
}
3363

64+
func (c *launchdConfig) toPlist() *plist {
65+
p := &plist{
66+
Version: "1.0",
67+
Config: xmlArray{Elements: []interface{}{}},
68+
}
69+
p.addKeyValue("Label", c.Label)
70+
p.addKeyValue("Program", c.Program)
71+
p.addKeyValue("StandardOutPath", c.StdOut)
72+
p.addKeyValue("StandardErrorPath", c.StdErr)
73+
74+
// IMPORTANT!!!
75+
// You must explicitly set the first argument to the executable path
76+
// because 'ProgramArguments' maps directly 'argv' in 'execvp'. The
77+
// programs calling this library likely will, by convention, assume the
78+
// first element of 'argv' is the executing program.
79+
// See https://www.unix.com/man-page/osx/5/launchd.plist/ and
80+
// https://man7.org/linux/man-pages/man3/execvp.3.html for more details.
81+
args := make([]string, len(c.Arguments)+1)
82+
args[0] = c.Program
83+
copy(args[1:], c.Arguments[:])
84+
p.addKeyValue("ProgramArguments", args)
85+
86+
return p
87+
}
88+
3489
type launchd struct {
3590
user common.UserProvider
3691
cmdExec common.CommandExecutor
@@ -104,11 +159,14 @@ func (l *launchd) Create(config *DaemonConfig, force bool) error {
104159

105160
// Generate the configuration
106161
var newPlist bytes.Buffer
107-
t, err := template.New(config.Label).Parse(launchTemplate)
162+
newPlist.WriteString(xml.Header)
163+
newPlist.WriteString(plistHeader)
164+
encoder := xml.NewEncoder(&newPlist)
165+
encoder.Indent("", " ")
166+
err := encoder.Encode(lConfig.toPlist())
108167
if err != nil {
109-
return fmt.Errorf("unable to generate launchd configuration: %w", err)
168+
return fmt.Errorf("could not encode plist: %w", err)
110169
}
111-
t.Execute(&newPlist, lConfig)
112170

113171
// Check the existing file - if it's the same as the new content, do not overwrite
114172
user, err := l.user.CurrentUser()

0 commit comments

Comments
 (0)