diff --git a/docs/index.html b/docs/index.html index fa7a003..3f81a55 100644 --- a/docs/index.html +++ b/docs/index.html @@ -608,6 +608,20 @@

/stream/:namespace/*key

 GET /stream/mysite.com/visits
 ⇒ data: {"value": 36}
+
+ +

/stream/hit/:namespace/*key

+

+ Increments and streams a counter using + + Server-Sent Events (SSE) + . +
+ Optionally specify a namespace.

+ +
+GET /stream/hit/mysite.com/visits
+⇒ data: {"value": 1}
 

/create/:namespace/*key

diff --git a/main.go b/main.go index 7b7fa31..ccc3a5c 100644 --- a/main.go +++ b/main.go @@ -183,6 +183,7 @@ func CreateRouter() *gin.Engine { route.GET("/hit/:namespace/:key/shield", HitShieldView) route.GET("/hit/:namespace/:key", HitView) route.GET("/stream/:namespace/*key", middleware.SSEMiddleware(), StreamValueView) + route.GET("/stream/hit/:namespace/*key", middleware.SSEMiddleware(), StreamHitView) route.POST("/create/:namespace/*key", CreateView) route.GET("/create/:namespace/*key", CreateView) diff --git a/routes.go b/routes.go index 49e2cb0..352c53e 100644 --- a/routes.go +++ b/routes.go @@ -133,6 +133,131 @@ func StreamValueView(c *gin.Context) { }) } +func StreamHitView(c *gin.Context) { + namespace, key := utils.GetNamespaceKey(c) + if namespace == "" || key == "" { + return + } + dbKey := utils.CreateKey(c, namespace, key, false) + if dbKey == "" { // error is handled in CreateKey + return + } + // Get data from Redis + val, err := Client.Incr(context.Background(), dbKey).Result() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get data. Try again later."}) + return + } + // check if val is is greater than the max value of an int + if val > math.MaxInt { + c.JSON(http.StatusBadRequest, gin.H{"error": "Value is too large. Max value is " + strconv.Itoa(math. + MaxInt), "message": "If you are seeing this error and have a legitimate use case, please contact me @ abacus@jasoncameron.dev"}) + return + } + + go func() { + utils.SetStream(dbKey, int(val)) // #nosec G115 -- This is safe as we perform a check ( + // see above) to ensure val is within the range of an int. + Client.Expire(context.Background(), dbKey, utils.BaseTTLPeriod) + }() + + // Set SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + // Initialize client channel with a buffer to prevent blocking + clientChan := make(chan int, 10) + + // Create a context that's canceled when the client disconnects + ctx := c.Request.Context() + + // Add this client to the event server for this specific key + utils.ValueEventServer.NewClients <- utils.KeyClientPair{ + Key: dbKey, + Client: clientChan, + } + + // Track if cleanup has been done + var cleanupDone bool + var cleanupMutex sync.Mutex + + // Ensure client is always removed when handler exits + defer func() { + cleanupMutex.Lock() + if !cleanupDone { + cleanupDone = true + cleanupMutex.Unlock() + + // Signal the event server to remove this client + select { + case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}: + // Successfully sent cleanup signal + case <-time.After(500 * time.Millisecond): + // Timed out waiting to send cleanup signal + log.Printf("Warning: Timed out sending cleanup signal for %s", dbKey) + } + } else { + cleanupMutex.Unlock() + } + }() + + // Monitor for client disconnection in a separate goroutine + go func() { + <-ctx.Done() // Wait for context cancellation (client disconnected) + + cleanupMutex.Lock() + if !cleanupDone { + cleanupDone = true + cleanupMutex.Unlock() + + log.Printf("Client disconnected for key %s, cleaning up", dbKey) + + // Signal the event server to remove this client + select { + case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}: + // Successfully sent cleanup signal + case <-time.After(500 * time.Millisecond): + // Timed out waiting to send cleanup signal + log.Printf("Warning: Timed out sending cleanup signal for %s after disconnect", dbKey) + } + } else { + cleanupMutex.Unlock() + } + }() + + // Send initial value + if count := val; err == nil { + // Keep your exact format + _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count)) + if err != nil { + log.Printf("Error writing to client: %v", err) + return + } + c.Writer.Flush() + } + + // Stream updates + c.Stream(func(w io.Writer) bool { + select { + case <-ctx.Done(): + return false + case count, ok := <-clientChan: + if !ok { + return false + } + // Keep your exact format + _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count)) + if err != nil { + log.Printf("Error writing to client: %v", err) + return false + } + c.Writer.Flush() + return true + } + }) +} + func HitView(c *gin.Context) { namespace, key := utils.GetNamespaceKey(c) if namespace == "" || key == "" { diff --git a/stream_test.go b/stream_test.go index 2cd89d8..60e0b8f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "strings" "sync" + "sync/atomic" "testing" "time" @@ -377,3 +378,313 @@ func extractValueFromEvent(event string) int { } return data.Value } + +// Helper function to get the current counter value via GET +func getCounterValue(t *testing.T, serverURL, namespace, key string) int { + client := &http.Client{} + resp, err := client.Get(serverURL + "/get/" + namespace + "/" + key) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode) + + var result struct { + Value int `json:"value"` + } + err = json.NewDecoder(resp.Body).Decode(&result) + require.NoError(t, err) + + return result.Value +} + +// TestStreamHitBasicFunctionality tests that the StreamHit endpoint correctly +// increments the counter and sends events +func TestStreamHitBasicFunctionality(t *testing.T) { + gin.SetMode(gin.TestMode) + router := setupTestRouter() + + // Create a counter first + createResp := httptest.NewRecorder() + createReq, _ := http.NewRequest("POST", "/create/test/streamhit-test", nil) + router.ServeHTTP(createResp, createReq) + assert.Equal(t, http.StatusCreated, createResp.Code) + + // For streaming tests, we need a real HTTP server + server := httptest.NewServer(router) + defer server.Close() + + // Use a real HTTP client to connect to the server + client := &http.Client{ + Timeout: 5 * time.Second, + } + + // Connect to the StreamHit endpoint + req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-test", nil) + require.NoError(t, err) + + req.Header.Set("Accept", "text/event-stream") + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Channel to collect received events + events := make(chan string, 10) + done := make(chan struct{}) + + // Process the SSE stream + go func() { + defer close(done) + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + select { + case events <- line: + // Event sent + case <-time.After(100 * time.Millisecond): + // Buffer full, drop event + t.Logf("Event buffer full, dropped: %s", line) + } + } + } + if err := scanner.Err(); err != nil { + t.Logf("Scanner error: %v", err) + } + }() + + // Check initial value (should be 1 since StreamHit increments on connect) + select { + case event := <-events: + assert.True(t, strings.HasPrefix(event, "data: {\"value\":")) + value := extractValueFromEvent(event) + assert.Equal(t, 1, value, "First connection should increment to 1") + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for initial event") + } + + // Close connections + resp.Body.Close() + + // Give some time for cleanup + time.Sleep(500 * time.Millisecond) + + // Verify proper cleanup + clientCount := countClientsForKey("K:test:streamhit-test") + assert.Equal(t, 0, clientCount, "Clients weren't properly cleaned up after disconnection") +} + +// TestStreamHitMultipleClients tests multiple clients connecting to the StreamHit endpoint +func TestStreamHitMultipleClients(t *testing.T) { + gin.SetMode(gin.TestMode) + router := setupTestRouter() + + // Create a counter + createResp := httptest.NewRecorder() + createReq, _ := http.NewRequest("POST", "/create/test/streamhit-multi", nil) + router.ServeHTTP(createResp, createReq) + assert.Equal(t, http.StatusCreated, createResp.Code) + + // Start a real HTTP server + server := httptest.NewServer(router) + defer server.Close() + + // Number of clients to test + numClients := 3 + + // Set up client trackers + type clientState struct { + resp *http.Response + events chan string + done chan struct{} + lastValue int + eventCount int + } + + clients := make([]*clientState, numClients) + + // Start all clients sequentially to observe incremental counting + for i := 0; i < numClients; i++ { + // Create client state + clients[i] = &clientState{ + events: make(chan string, 10), + done: make(chan struct{}), + } + + // Create request + req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-multi", nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + // Connect client + client := &http.Client{ + Timeout: 5 * time.Second, + } + resp, err := client.Do(req) + require.NoError(t, err) + clients[i].resp = resp + + // Process events + go func(idx int) { + defer close(clients[idx].done) + scanner := bufio.NewScanner(clients[idx].resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + select { + case clients[idx].events <- line: + // Event sent + default: + // Buffer full, drop event + } + } + } + }(i) + + // Give time for client to connect and get initial value + time.Sleep(100 * time.Millisecond) + } + + // Verify all clients receive initial values (should be sequential) + for i := 0; i < numClients; i++ { + select { + case event := <-clients[i].events: + clients[i].lastValue = extractValueFromEvent(event) + clients[i].eventCount++ + // Each client should get its connection number as initial value + assert.Equal(t, i+1, clients[i].lastValue, "Client %d received incorrect initial value", i) + case <-time.After(1 * time.Second): + t.Fatalf("Timeout waiting for client %d initial event", i) + } + } + + // Connect a new client to increment the counter again + client := &http.Client{} + req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-multi", nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + // Give time for events to propagate + time.Sleep(300 * time.Millisecond) + + // Verify all clients received the update + for i := 0; i < numClients; i++ { + // Try to get the latest event + select { + case event := <-clients[i].events: + clients[i].lastValue = extractValueFromEvent(event) + clients[i].eventCount++ + case <-time.After(500 * time.Millisecond): + t.Fatalf("Client %d didn't receive update event", i) + } + + assert.Equal(t, i+1, clients[i].lastValue, "Client %d has incorrect value after new connection", i) + } + + // Disconnect clients one by one and verify cleanup + for i := 0; i < numClients; i++ { + // Close client connection + clients[i].resp.Body.Close() + + // Give time for cleanup + time.Sleep(200 * time.Millisecond) + + // Verify decreasing client count (plus the extra client we added) + expectedCount := numClients - (i + 1) + 1 // remaining clients + extra test client + clientCount := countClientsForKey("K:test:streamhit-multi") + assert.Equal(t, expectedCount, clientCount, "Client wasn't properly cleaned up after disconnection") + } + + // Clean up extra client + resp.Body.Close() + time.Sleep(200 * time.Millisecond) +} + +// TestStreamHitConcurrencyStress tests the StreamHit endpoint under high concurrency conditions +func TestStreamHitConcurrencyStress(t *testing.T) { + // Skip in normal testing as this is a long stress test + if testing.Short() { + t.Skip("Skipping stress test in short mode") + } + + gin.SetMode(gin.ReleaseMode) // Reduce logging noise + router := setupTestRouter() + + // Create a counter for stress testing + createResp := httptest.NewRecorder() + createReq, _ := http.NewRequest("POST", "/create/test/streamhit-stress", nil) + router.ServeHTTP(createResp, createReq) + require.Equal(t, http.StatusCreated, createResp.Code) + + server := httptest.NewServer(router) + defer server.Close() + + numClients := 20 + clientDuration := 300 * time.Millisecond + + // Track the current counter value (should increase with each client) + var currentValue atomic.Int64 + currentValue.Store(0) + + // Start with no clients + initialCount := countClientsForKey("K:test:streamhit-stress") + assert.Equal(t, 0, initialCount) + + // Launch many concurrent clients + var wg sync.WaitGroup + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + client := &http.Client{} + + // Create request + req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-stress", nil) + if err != nil { + t.Logf("Error creating request: %v", err) + return + } + req.Header.Set("Accept", "text/event-stream") + + // Send request + resp, err := client.Do(req) + if err != nil { + t.Logf("Error connecting: %v", err) + return + } + defer resp.Body.Close() + + // Read initial value to verify counter is working + scanner := bufio.NewScanner(resp.Body) + if scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + value := extractValueFromEvent(line) + if value > 0 { + currentValue.Store(int64(value)) + } + } + } + + // Keep connection open for the duration + time.Sleep(clientDuration) + }(i) + + // Stagger client creation slightly + time.Sleep(5 * time.Millisecond) + } + + wg.Wait() + + // Give extra time for any cleanup + time.Sleep(1 * time.Second) + + finalCount := countClientsForKey("K:test:streamhit-stress") + assert.Equal(t, 0, finalCount, "Not all clients were cleaned up after stress test") + + counterValue := getCounterValue(t, server.URL, "test", "streamhit-stress") + assert.Equal(t, numClients, counterValue, "Counter didn't reach expected value after all clients") +}