diff --git a/lib/srv/git/forward.go b/lib/srv/git/forward.go index a706537a2da97..44609550e5cdb 100644 --- a/lib/srv/git/forward.go +++ b/lib/srv/git/forward.go @@ -520,7 +520,7 @@ func verifyRemoteHost(targetServer types.Server) ssh.HostKeyCallback { return func(hostname string, remote net.Addr, key ssh.PublicKey) error { switch targetServer.GetSubKind() { case types.SubKindGitHub: - return githubFingerprints.checkServerKey(key) + return githubServerKeys.check(key) default: return trace.BadParameter("unsupported subkind %q", targetServer.GetSubKind()) } diff --git a/lib/srv/git/github.go b/lib/srv/git/github.go index ff298de86f5f0..9592e3ee8c8ba 100644 --- a/lib/srv/git/github.go +++ b/lib/srv/git/github.go @@ -22,9 +22,9 @@ import ( "context" "encoding/json" "log/slog" - "maps" "net/http" "slices" + "strings" "sync" "time" @@ -42,133 +42,94 @@ import ( "github.com/gravitational/teleport/lib/sshutils" ) -type githubMetadataClient interface { - fetchETag() (string, error) - fetchFingerprints() ([]string, string, error) +// githubServerKeyManager downloads SSH keys from the GitHub meta API and does a +// lazy refresh every hour. The keys are used to verify GitHub server when +// forwarding Git commands to it. +type githubServerKeyManager struct { + mu sync.Mutex + keys []string + lastCheck time.Time + etag string + + clock clockwork.Clock + apiEndpoint string + client *http.Client } -// githubFingerprintManager downloads SSH fingerprints from the GitHub meta API -// and does a lazy refresh every hour. The fingerprints are used to verify -// GitHub server when forwarding Git commands to it. -type githubFingerprintManager struct { - mu sync.RWMutex - fingerprints []string - lastCheck time.Time - etag string - - clock clockwork.Clock - client githubMetadataClient -} - -func newGithubFingerprintManager() *githubFingerprintManager { - return &githubFingerprintManager{ - clock: clockwork.NewRealClock(), - client: newGithubMetadataTTPClient(), +func newGitHubServeKeyManager() *githubServerKeyManager { + return &githubServerKeyManager{ + clock: clockwork.NewRealClock(), + apiEndpoint: "https://api.github.com/meta", + client: &http.Client{ + Timeout: defaults.HTTPRequestTimeout, + }, } } -func (g *githubFingerprintManager) checkServerKey(key ssh.PublicKey) error { - actualFingerprint := ssh.FingerprintSHA256(key) - for _, fingerprint := range g.getKnownFingerprints() { - if sshutils.EqualFingerprints(actualFingerprint, fingerprint) { - return nil - } - } - return trace.BadParameter("cannot verify github.com: unknown fingerprint %v algo %v", actualFingerprint, key.Type()) -} +func (m *githubServerKeyManager) check(targetKey ssh.PublicKey) error { + m.mu.Lock() + defer m.mu.Unlock() -func (g *githubFingerprintManager) getKnownFingerprints() []string { - const refreshDuration = time.Hour - g.mu.RLock() - if g.clock.Now().Sub(g.lastCheck) < refreshDuration { - defer g.mu.RUnlock() - return g.fingerprints + // Refresh every 24 hours. + if m.clock.Now().Sub(m.lastCheck) > time.Hour*24 { + m.refreshLocked() } - g.mu.RUnlock() - g.mu.Lock() - defer g.mu.Unlock() - if g.clock.Now().Sub(g.lastCheck) < refreshDuration { - return g.fingerprints + // Remove newline from ssh.MarshalAuthorizedKey. + key := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(targetKey))) + if slices.Contains(m.keys, key) { + return nil } + return trace.BadParameter("cannot verify github.com: unknown server key %q", key) +} +func (m *githubServerKeyManager) refreshLocked() { + ctx := context.Background() logger := slog.With(teleport.ComponentKey, "git:github") - // Check if eTag is the same to avoid downloading the whole thing which - // contains a lot of irrelevant info. - if g.etag != "" { - etag, err := g.client.fetchETag() - switch { - case err != nil: - logger.WarnContext(context.Background(), "Failed to fetch eTag from GitHub meta API", "error", err) - // Don't give up yet if HEAD fails. - - case etag == g.etag: - g.lastCheck = g.clock.Now() - logger.DebugContext(context.Background(), "ETag did not change for GitHub meta API") - return g.fingerprints - - default: - logger.DebugContext(context.Background(), "ETag changed for GitHub meta API", "new", etag) - } - } - - fingerprints, etag, err := g.client.fetchFingerprints() + // Meta API reference: + // https://docs.github.com/en/rest/meta/meta#get-github-meta-information + req, err := http.NewRequest("GET", m.apiEndpoint, nil) if err != nil { - logger.WarnContext(context.Background(), "Failed to fetch fingerprints from GitHub meta API", "error", err) - return g.fingerprints + logger.WarnContext(ctx, "Failed to make request for GitHub meta API", "error", err) + return } - logger.DebugContext(context.Background(), "Found SSH fingerprints from Github meta API", "fingerprints", fingerprints, "etag", etag) - g.etag = etag - g.fingerprints = fingerprints - g.lastCheck = g.clock.Now() - return g.fingerprints -} - -var githubFingerprints = newGithubFingerprintManager() - -type githubMetadataHTTPClient struct { - api string - client *http.Client -} -func newGithubMetadataTTPClient() *githubMetadataHTTPClient { - return &githubMetadataHTTPClient{ - api: "https://api.github.com/meta", - client: &http.Client{ - Timeout: defaults.HTTPRequestTimeout, - }, + // ETag check. + if m.etag != "" { + req.Header.Set("If-None-Match", m.etag) } -} -func (c *githubMetadataHTTPClient) fetchETag() (string, error) { - resp, err := http.Head(c.api) + resp, err := m.client.Do(req) if err != nil { - return "", trace.Wrap(err) + logger.WarnContext(ctx, "Failed to fetch GitHub meta API", "error", err) + return } - return resp.Header.Get("ETag"), nil -} + defer resp.Body.Close() -func (c *githubMetadataHTTPClient) fetchFingerprints() ([]string, string, error) { - resp, err := http.Get(c.api) - if err != nil { - return nil, "", trace.Wrap(err) + // Nothing changed. Just update the last check time. + if resp.StatusCode == http.StatusNotModified { + logger.DebugContext(ctx, "GitHub metadata is up-to-date") + m.lastCheck = m.clock.Now() + return } - defer resp.Body.Close() - // Meta API reference: - // https://docs.github.com/en/rest/meta/meta?apiVersion=2022-11-28#get-github-meta-information meta := struct { - // Fingerprints lists the fingerprints by algo type. - Fingerprints map[string]string `json:"ssh_key_fingerprints"` + SSHKeys []string `json:"ssh_keys"` }{} if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { - return nil, "", trace.Wrap(err) + logger.WarnContext(ctx, "Failed to decode response from GitHub meta API", "error", err) + return } - return slices.Collect(maps.Values(meta.Fingerprints)), resp.Header.Get("ETag"), nil + m.etag = resp.Header.Get("ETag") + m.keys = meta.SSHKeys + m.lastCheck = m.clock.Now() + logger.DebugContext(ctx, "Fetched GitHub metadata", "ssh_keys", m.keys, "etag", m.etag) } +var githubServerKeys = newGitHubServeKeyManager() + // AuthPreferenceGetter is an interface for retrieving the current configured // cluster auth preference. type AuthPreferenceGetter interface { diff --git a/lib/srv/git/github_test.go b/lib/srv/git/github_test.go index 4a5ff77d523b7..c938c4e086445 100644 --- a/lib/srv/git/github_test.go +++ b/lib/srv/git/github_test.go @@ -21,11 +21,16 @@ package git import ( "context" "crypto/rand" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "time" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/grpc" @@ -33,6 +38,7 @@ import ( integrationv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1" "github.com/gravitational/teleport/api/types" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/fixtures" ) type fakeAuthPreferenceGetter struct { @@ -143,3 +149,58 @@ func TestMakeGitHubSigner(t *testing.T) { }) } } + +func Test_githubServerKeyManager(t *testing.T) { + clock := clockwork.NewFakeClock() + etag := uuid.NewString() + serverCalled := 0 + serverCalledWithETagMatch := 0 + body, err := json.Marshal(map[string][]string{ + "ssh_keys": []string{fixtures.SSHCAPublicKey}, + }) + require.NoError(t, err) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + serverCalled++ + if req.Header.Get("If-None-Match") == etag { + serverCalledWithETagMatch++ + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Add("Etag", etag) + w.WriteHeader(http.StatusOK) + w.Write(body) + })) + t.Cleanup(server.Close) + + validKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fixtures.SSHCAPublicKey)) + require.NoError(t, err) + invalidKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(`ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBGv+gN2C23P08ieJRA9gU/Ik4bsOh3Kw193UYscJDw41mATj+Kqyf45Rmj8F8rs3i7mYKRXXu1IjNRBzNgpXxqc=`)) + require.NoError(t, err) + + m := newGitHubServeKeyManager() + m.apiEndpoint = server.URL + m.clock = clock + + // First check should download from the server. + require.NoError(t, m.check(validKey)) + require.Error(t, m.check(invalidKey)) + assert.Equal(t, 1, serverCalled) + assert.Equal(t, 0, serverCalledWithETagMatch) + + // Advance time for a refresh. + clock.Advance(time.Hour * 30) + require.NoError(t, m.check(validKey)) + require.Error(t, m.check(invalidKey)) + assert.Equal(t, 2, serverCalled) + assert.Equal(t, 1, serverCalledWithETagMatch) + + // Simulate an etag change and advance time for a refresh. + m.keys = nil + m.etag = "does-not-match" + clock.Advance(time.Hour * 30) + require.NoError(t, m.check(validKey)) + require.Error(t, m.check(invalidKey)) + assert.Equal(t, 3, serverCalled) + assert.Equal(t, 1, serverCalledWithETagMatch) +}