Skip to content

Commit

Permalink
feat: Enhance API route initialization and shutdown with robust error…
Browse files Browse the repository at this point in the history
… handling

- Refactor route initialization in `initRoutes()` to use a more structured and resilient approach
- Add deferred error recovery for route initialization to prevent application startup failures
- Implement a new `Shutdown()` method for graceful API controller cleanup
- Add context-based CPU monitoring with proper cancellation support
- Improve authentication middleware to handle browser and API client scenarios differently
- Enhance token extraction and validation in auth tests with more comprehensive scenarios
- Add time zone and error handling considerations in weather route methods
  • Loading branch information
tphakala committed Feb 27, 2025
1 parent 2f74f3d commit 84ad717
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 89 deletions.
66 changes: 41 additions & 25 deletions internal/api/v2/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,40 @@ func (c *Controller) initRoutes() {
// Health check endpoint - publicly accessible
c.Group.GET("/health", c.HealthCheck)

// Initialize detection routes
c.initDetectionRoutes()

// Analytics routes - for statistics and data analysis
c.initAnalyticsRoutes()

// Weather routes - for weather data and detection conditions
c.initWeatherRoutes()

// System routes (for hardware and software information) - protected
c.initSystemRoutes()

// Settings routes (for application configuration) - protected
c.initSettingsRoutes()

// Stream routes (for real-time data) - protected
c.initStreamRoutes()
// Initialize route groups with proper error handling and logging
routeInitializers := []struct {
name string
fn func()
}{
{"detection routes", c.initDetectionRoutes},
{"analytics routes", c.initAnalyticsRoutes},
{"weather routes", c.initWeatherRoutes},
{"system routes", c.initSystemRoutes},
{"settings routes", c.initSettingsRoutes},
{"stream routes", c.initStreamRoutes},
{"integration routes", c.initIntegrationsRoutes},
{"control routes", c.initControlRoutes},
{"auth routes", c.initAuthRoutes},
{"media routes", c.initMediaRoutes},
}

// Integration routes (for external services) - protected
c.initIntegrationsRoutes()
for _, initializer := range routeInitializers {
c.Debug("Initializing %s...", initializer.name)

// Control routes (for application control) - protected
c.initControlRoutes()
// Use a deferred function to recover from panics during route initialization
func() {
defer func() {
if r := recover(); r != nil {
c.logger.Printf("PANIC during %s initialization: %v", initializer.name, r)
}
}()

// Authentication routes - partially protected based on their implementation
c.initAuthRoutes()
// Call the initializer
initializer.fn()

// Initialize media routes - protected
c.initMediaRoutes()
c.Debug("Successfully initialized %s", initializer.name)
}()
}
}

// HealthCheck handles the API health check endpoint
Expand All @@ -109,6 +114,17 @@ func (c *Controller) HealthCheck(ctx echo.Context) error {
})
}

// Shutdown performs cleanup of all resources used by the API controller
// This should be called when the application is shutting down
func (c *Controller) Shutdown() {
// Call shutdown methods of individual components
// Currently, only the system component needs cleanup
StopCPUMonitoring()

// Log shutdown
c.Debug("API Controller shutting down, CPU monitoring stopped")
}

// Error response structure
type ErrorResponse struct {
Error string `json:"error"`
Expand Down
151 changes: 107 additions & 44 deletions internal/api/v2/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http/httptest"
"os"
"strings"
"sync"
"testing"

"errors"
Expand All @@ -27,23 +28,29 @@ type SecurityManager interface {
// MockSecurityManager implements a mock for the authentication system
type MockSecurityManager struct {
mock.Mock
mu sync.Mutex // Added mutex for concurrent safety
}

// Validate the token
func (m *MockSecurityManager) ValidateToken(token string) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
args := m.Called(token)
return args.Bool(0), args.Error(1)
}

// Generate a new token
func (m *MockSecurityManager) GenerateToken(username string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
args := m.Called(username)
return args.String(0), args.Error(1)
}

// MockServer implements the interfaces required for auth testing
type MockServer struct {
mock.Mock
mu sync.Mutex // Added mutex for concurrent safety
AuthEnabled bool
ValidTokens map[string]bool
Password string
Expand All @@ -52,6 +59,9 @@ type MockServer struct {

// ValidateAccessToken validates an access token
func (m *MockServer) ValidateAccessToken(token string) bool {
m.mu.Lock()
defer m.mu.Unlock()

// First check if we have a direct mock expectation
if m.Mock.ExpectedCalls != nil {
for _, call := range m.Mock.ExpectedCalls {
Expand All @@ -73,33 +83,62 @@ func (m *MockServer) ValidateAccessToken(token string) bool {

// IsAccessAllowed checks if access is allowed
func (m *MockServer) IsAccessAllowed(c echo.Context) bool {
m.mu.Lock()
defer m.mu.Unlock()
args := m.Called(c)
return args.Bool(0)
}

// isAuthenticationEnabled checks if authentication is enabled
func (m *MockServer) isAuthenticationEnabled(c echo.Context) bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.AuthEnabled
}

// AuthenticateBasic performs basic authentication
func (m *MockServer) AuthenticateBasic(c echo.Context, username, password string) bool {
m.mu.Lock()
defer m.mu.Unlock()
args := m.Called(c, username, password)
return args.Bool(0)
}

// GetUsername returns the authenticated username
func (m *MockServer) GetUsername(c echo.Context) string {
m.mu.Lock()
defer m.mu.Unlock()
args := m.Called(c)
return args.String(0)
}

// GetAuthMethod returns the authentication method
func (m *MockServer) GetAuthMethod(c echo.Context) string {
m.mu.Lock()
defer m.mu.Unlock()
args := m.Called(c)
return args.String(0)
}

// extractTokenFromContext is a utility function to consistently extract tokens
// from either context or authorization header
func extractTokenFromContext(c echo.Context) string {
// First check if token was set directly in context
if tokenVal := c.Get("token"); tokenVal != nil {
if token, ok := tokenVal.(string); ok {
return token
}
}

// Next, try to extract from Authorization header (Bearer token)
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" && strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimPrefix(authHeader, "Bearer ")
}

return ""
}

// TestAuthMiddleware tests the authentication middleware
func TestAuthMiddleware(t *testing.T) {
// Setup
Expand Down Expand Up @@ -166,6 +205,14 @@ func TestAuthMiddleware(t *testing.T) {
expectStatus: http.StatusUnauthorized,
serverSetup: nil,
},
{
name: "Syntactically corrupted token",
token: "invalid.jwt.format-missing-segments",
validateReturn: false,
validateError: errors.New("invalid token format"),
expectStatus: http.StatusUnauthorized,
serverSetup: nil,
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -350,47 +397,6 @@ func TestLogin(t *testing.T) {
}
}

// ValidateToken implements a test token validation function directly within test context
func validateToken(c echo.Context, token string) (bool, error) {
// Get server from context which should contain our mock
server := c.Get("server")
if server == nil {
return false, errors.New("server not available in context")
}

// Try to use the mock server's security manager
if mockServer, ok := server.(*MockServer); ok && mockServer.Security != nil {
return mockServer.Security.ValidateToken(token)
}

return false, errors.New("validation failed")
}

// mockValidateToken is an implementation for the test handler
func mockValidateToken(c echo.Context) error {
token := c.Get("token").(string)

if token == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Token is required")
}

valid, err := validateToken(c, token)

// Handle specific error types
switch {
case err != nil && err.Error() == "token expired":
return echo.NewHTTPError(http.StatusUnauthorized, "Token expired")
case err != nil && err.Error() == "missing claims":
return echo.NewHTTPError(http.StatusBadRequest, "Invalid token format")
case !valid:
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid token")
}

return c.JSON(http.StatusOK, map[string]interface{}{
"valid": true,
})
}

// TestValidateToken tests the token validation endpoint
func TestValidateToken(t *testing.T) {
// Setup
Expand All @@ -404,6 +410,47 @@ func TestValidateToken(t *testing.T) {
mockServer.AuthEnabled = true
mockServer.Security = mockSecurity

// validateToken is now defined within test scope to keep the package-level namespace clean
validateToken := func(c echo.Context, token string) (bool, error) {
// Get server from context which should contain our mock
server := c.Get("server")
if server == nil {
return false, errors.New("server not available in context")
}

// Try to use the mock server's security manager
if mockServer, ok := server.(*MockServer); ok && mockServer.Security != nil {
return mockServer.Security.ValidateToken(token)
}

return false, errors.New("validation failed")
}

// mockValidateToken is now using the common token extraction utility
mockValidateToken := func(c echo.Context) error {
token := extractTokenFromContext(c)

if token == "" {
return echo.NewHTTPError(http.StatusBadRequest, "Token is required")
}

valid, err := validateToken(c, token)

// Handle specific error types
switch {
case err != nil && err.Error() == "token expired":
return echo.NewHTTPError(http.StatusUnauthorized, "Token expired")
case err != nil && err.Error() == "missing claims":
return echo.NewHTTPError(http.StatusBadRequest, "Invalid token format")
case !valid:
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid token")
}

return c.JSON(http.StatusOK, map[string]interface{}{
"valid": true,
})
}

// Test cases
testCases := []struct {
name string
Expand Down Expand Up @@ -461,6 +508,14 @@ func TestValidateToken(t *testing.T) {
expectStatus: http.StatusUnauthorized,
expectMessage: "Invalid token",
},
{
name: "Malformed JWT token",
token: "not.a.valid.jwt.token",
validateReturn: false,
validateError: errors.New("malformed token"),
expectStatus: http.StatusUnauthorized,
expectMessage: "Invalid token",
},
}

for _, tc := range testCases {
Expand All @@ -476,13 +531,21 @@ func TestValidateToken(t *testing.T) {
mockServer.On("ValidateAccessToken", tc.token).Return(tc.validateReturn).Once()
}

// Create a request
// Create a request - test both ways of providing the token
req := httptest.NewRequest(http.MethodPost, "/api/v2/auth/validate", http.NoBody)

// Randomly alternate between setting token in header vs context to test both pathways
if tc.name == "Valid token" || tc.name == "Invalid token" || tc.name == "Malformed JWT token" {
req.Header.Set("Authorization", "Bearer "+tc.token)
}

rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

// Add token to context
c.Set("token", tc.token)
// Set the token in context (for test cases not using Authorization header)
if tc.name != "Valid token" && tc.name != "Invalid token" && tc.name != "Malformed JWT token" {
c.Set("token", tc.token)
}

// Set the mock server in the context
c.Set("server", mockServer)
Expand Down
2 changes: 2 additions & 0 deletions internal/api/v2/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
)

// InitializeAPI sets up the JSON API endpoints in the provided Echo instance
// The returned Controller has a Shutdown method that should be called during application shutdown
// to properly clean up resources and stop background goroutines
func InitializeAPI(
e *echo.Echo,
ds datastore.Interface,
Expand Down
26 changes: 22 additions & 4 deletions internal/api/v2/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,28 @@ func (c *Controller) AuthMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
}

if !authenticated {
// Return JSON error for API calls
return ctx.JSON(http.StatusUnauthorized, map[string]string{
"error": "Authentication required",
})
// Determine if request is from a browser or an API client
// Browsers typically include "text/html" in their Accept header
acceptHeader := ctx.Request().Header.Get("Accept")
isBrowserRequest := strings.Contains(acceptHeader, "text/html")

if isBrowserRequest {
// For browser requests, redirect to login page
loginPath := "/login"

// Optionally store the original URL for post-login redirect
originURL := ctx.Request().URL.String()
if originURL != loginPath && !strings.Contains(originURL, "login") {
loginPath += "?redirect=" + originURL
}

return ctx.Redirect(http.StatusFound, loginPath)
} else {
// For API clients, return JSON error response
return ctx.JSON(http.StatusUnauthorized, map[string]string{
"error": "Authentication required",
})
}
}

return next(ctx)
Expand Down
Loading

0 comments on commit 84ad717

Please sign in to comment.