diff --git a/docs/index.html b/docs/index.html
index fa7a003..3f81a55 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -608,6 +608,20 @@
/stream/:namespace/*key
GET /stream/mysite.com/visits
⇒ data: {"value": 36}
+
+
+ /stream/hit/:namespace/*key
+
+ Increments and streams a counter using
+
+ Server-Sent Events (SSE)
+ .
+
+ Optionally specify a namespace.
+
+
+GET /stream/hit/mysite.com/visits
+⇒ data: {"value": 1}
/create/:namespace/*key
diff --git a/main.go b/main.go
index 7b7fa31..ccc3a5c 100644
--- a/main.go
+++ b/main.go
@@ -183,6 +183,7 @@ func CreateRouter() *gin.Engine {
route.GET("/hit/:namespace/:key/shield", HitShieldView)
route.GET("/hit/:namespace/:key", HitView)
route.GET("/stream/:namespace/*key", middleware.SSEMiddleware(), StreamValueView)
+ route.GET("/stream/hit/:namespace/*key", middleware.SSEMiddleware(), StreamHitView)
route.POST("/create/:namespace/*key", CreateView)
route.GET("/create/:namespace/*key", CreateView)
diff --git a/routes.go b/routes.go
index 49e2cb0..352c53e 100644
--- a/routes.go
+++ b/routes.go
@@ -133,6 +133,131 @@ func StreamValueView(c *gin.Context) {
})
}
+func StreamHitView(c *gin.Context) {
+ namespace, key := utils.GetNamespaceKey(c)
+ if namespace == "" || key == "" {
+ return
+ }
+ dbKey := utils.CreateKey(c, namespace, key, false)
+ if dbKey == "" { // error is handled in CreateKey
+ return
+ }
+ // Get data from Redis
+ val, err := Client.Incr(context.Background(), dbKey).Result()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get data. Try again later."})
+ return
+ }
+ // check if val is is greater than the max value of an int
+ if val > math.MaxInt {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "Value is too large. Max value is " + strconv.Itoa(math.
+ MaxInt), "message": "If you are seeing this error and have a legitimate use case, please contact me @ abacus@jasoncameron.dev"})
+ return
+ }
+
+ go func() {
+ utils.SetStream(dbKey, int(val)) // #nosec G115 -- This is safe as we perform a check (
+ // see above) to ensure val is within the range of an int.
+ Client.Expire(context.Background(), dbKey, utils.BaseTTLPeriod)
+ }()
+
+ // Set SSE headers
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+
+ // Initialize client channel with a buffer to prevent blocking
+ clientChan := make(chan int, 10)
+
+ // Create a context that's canceled when the client disconnects
+ ctx := c.Request.Context()
+
+ // Add this client to the event server for this specific key
+ utils.ValueEventServer.NewClients <- utils.KeyClientPair{
+ Key: dbKey,
+ Client: clientChan,
+ }
+
+ // Track if cleanup has been done
+ var cleanupDone bool
+ var cleanupMutex sync.Mutex
+
+ // Ensure client is always removed when handler exits
+ defer func() {
+ cleanupMutex.Lock()
+ if !cleanupDone {
+ cleanupDone = true
+ cleanupMutex.Unlock()
+
+ // Signal the event server to remove this client
+ select {
+ case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}:
+ // Successfully sent cleanup signal
+ case <-time.After(500 * time.Millisecond):
+ // Timed out waiting to send cleanup signal
+ log.Printf("Warning: Timed out sending cleanup signal for %s", dbKey)
+ }
+ } else {
+ cleanupMutex.Unlock()
+ }
+ }()
+
+ // Monitor for client disconnection in a separate goroutine
+ go func() {
+ <-ctx.Done() // Wait for context cancellation (client disconnected)
+
+ cleanupMutex.Lock()
+ if !cleanupDone {
+ cleanupDone = true
+ cleanupMutex.Unlock()
+
+ log.Printf("Client disconnected for key %s, cleaning up", dbKey)
+
+ // Signal the event server to remove this client
+ select {
+ case utils.ValueEventServer.ClosedClients <- utils.KeyClientPair{Key: dbKey, Client: clientChan}:
+ // Successfully sent cleanup signal
+ case <-time.After(500 * time.Millisecond):
+ // Timed out waiting to send cleanup signal
+ log.Printf("Warning: Timed out sending cleanup signal for %s after disconnect", dbKey)
+ }
+ } else {
+ cleanupMutex.Unlock()
+ }
+ }()
+
+ // Send initial value
+ if count := val; err == nil {
+ // Keep your exact format
+ _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
+ if err != nil {
+ log.Printf("Error writing to client: %v", err)
+ return
+ }
+ c.Writer.Flush()
+ }
+
+ // Stream updates
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case <-ctx.Done():
+ return false
+ case count, ok := <-clientChan:
+ if !ok {
+ return false
+ }
+ // Keep your exact format
+ _, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
+ if err != nil {
+ log.Printf("Error writing to client: %v", err)
+ return false
+ }
+ c.Writer.Flush()
+ return true
+ }
+ })
+}
+
func HitView(c *gin.Context) {
namespace, key := utils.GetNamespaceKey(c)
if namespace == "" || key == "" {
diff --git a/stream_test.go b/stream_test.go
index 2cd89d8..60e0b8f 100644
--- a/stream_test.go
+++ b/stream_test.go
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -377,3 +378,313 @@ func extractValueFromEvent(event string) int {
}
return data.Value
}
+
+// Helper function to get the current counter value via GET
+func getCounterValue(t *testing.T, serverURL, namespace, key string) int {
+ client := &http.Client{}
+ resp, err := client.Get(serverURL + "/get/" + namespace + "/" + key)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ require.Equal(t, http.StatusOK, resp.StatusCode)
+
+ var result struct {
+ Value int `json:"value"`
+ }
+ err = json.NewDecoder(resp.Body).Decode(&result)
+ require.NoError(t, err)
+
+ return result.Value
+}
+
+// TestStreamHitBasicFunctionality tests that the StreamHit endpoint correctly
+// increments the counter and sends events
+func TestStreamHitBasicFunctionality(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ router := setupTestRouter()
+
+ // Create a counter first
+ createResp := httptest.NewRecorder()
+ createReq, _ := http.NewRequest("POST", "/create/test/streamhit-test", nil)
+ router.ServeHTTP(createResp, createReq)
+ assert.Equal(t, http.StatusCreated, createResp.Code)
+
+ // For streaming tests, we need a real HTTP server
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ // Use a real HTTP client to connect to the server
+ client := &http.Client{
+ Timeout: 5 * time.Second,
+ }
+
+ // Connect to the StreamHit endpoint
+ req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-test", nil)
+ require.NoError(t, err)
+
+ req.Header.Set("Accept", "text/event-stream")
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Channel to collect received events
+ events := make(chan string, 10)
+ done := make(chan struct{})
+
+ // Process the SSE stream
+ go func() {
+ defer close(done)
+ scanner := bufio.NewScanner(resp.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if strings.HasPrefix(line, "data: ") {
+ select {
+ case events <- line:
+ // Event sent
+ case <-time.After(100 * time.Millisecond):
+ // Buffer full, drop event
+ t.Logf("Event buffer full, dropped: %s", line)
+ }
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ t.Logf("Scanner error: %v", err)
+ }
+ }()
+
+ // Check initial value (should be 1 since StreamHit increments on connect)
+ select {
+ case event := <-events:
+ assert.True(t, strings.HasPrefix(event, "data: {\"value\":"))
+ value := extractValueFromEvent(event)
+ assert.Equal(t, 1, value, "First connection should increment to 1")
+ case <-time.After(1 * time.Second):
+ t.Fatal("Timeout waiting for initial event")
+ }
+
+ // Close connections
+ resp.Body.Close()
+
+ // Give some time for cleanup
+ time.Sleep(500 * time.Millisecond)
+
+ // Verify proper cleanup
+ clientCount := countClientsForKey("K:test:streamhit-test")
+ assert.Equal(t, 0, clientCount, "Clients weren't properly cleaned up after disconnection")
+}
+
+// TestStreamHitMultipleClients tests multiple clients connecting to the StreamHit endpoint
+func TestStreamHitMultipleClients(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ router := setupTestRouter()
+
+ // Create a counter
+ createResp := httptest.NewRecorder()
+ createReq, _ := http.NewRequest("POST", "/create/test/streamhit-multi", nil)
+ router.ServeHTTP(createResp, createReq)
+ assert.Equal(t, http.StatusCreated, createResp.Code)
+
+ // Start a real HTTP server
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ // Number of clients to test
+ numClients := 3
+
+ // Set up client trackers
+ type clientState struct {
+ resp *http.Response
+ events chan string
+ done chan struct{}
+ lastValue int
+ eventCount int
+ }
+
+ clients := make([]*clientState, numClients)
+
+ // Start all clients sequentially to observe incremental counting
+ for i := 0; i < numClients; i++ {
+ // Create client state
+ clients[i] = &clientState{
+ events: make(chan string, 10),
+ done: make(chan struct{}),
+ }
+
+ // Create request
+ req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-multi", nil)
+ require.NoError(t, err)
+ req.Header.Set("Accept", "text/event-stream")
+
+ // Connect client
+ client := &http.Client{
+ Timeout: 5 * time.Second,
+ }
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ clients[i].resp = resp
+
+ // Process events
+ go func(idx int) {
+ defer close(clients[idx].done)
+ scanner := bufio.NewScanner(clients[idx].resp.Body)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if strings.HasPrefix(line, "data: ") {
+ select {
+ case clients[idx].events <- line:
+ // Event sent
+ default:
+ // Buffer full, drop event
+ }
+ }
+ }
+ }(i)
+
+ // Give time for client to connect and get initial value
+ time.Sleep(100 * time.Millisecond)
+ }
+
+ // Verify all clients receive initial values (should be sequential)
+ for i := 0; i < numClients; i++ {
+ select {
+ case event := <-clients[i].events:
+ clients[i].lastValue = extractValueFromEvent(event)
+ clients[i].eventCount++
+ // Each client should get its connection number as initial value
+ assert.Equal(t, i+1, clients[i].lastValue, "Client %d received incorrect initial value", i)
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timeout waiting for client %d initial event", i)
+ }
+ }
+
+ // Connect a new client to increment the counter again
+ client := &http.Client{}
+ req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-multi", nil)
+ require.NoError(t, err)
+ req.Header.Set("Accept", "text/event-stream")
+ resp, err := client.Do(req)
+ require.NoError(t, err)
+ defer resp.Body.Close()
+
+ // Give time for events to propagate
+ time.Sleep(300 * time.Millisecond)
+
+ // Verify all clients received the update
+ for i := 0; i < numClients; i++ {
+ // Try to get the latest event
+ select {
+ case event := <-clients[i].events:
+ clients[i].lastValue = extractValueFromEvent(event)
+ clients[i].eventCount++
+ case <-time.After(500 * time.Millisecond):
+ t.Fatalf("Client %d didn't receive update event", i)
+ }
+
+ assert.Equal(t, i+1, clients[i].lastValue, "Client %d has incorrect value after new connection", i)
+ }
+
+ // Disconnect clients one by one and verify cleanup
+ for i := 0; i < numClients; i++ {
+ // Close client connection
+ clients[i].resp.Body.Close()
+
+ // Give time for cleanup
+ time.Sleep(200 * time.Millisecond)
+
+ // Verify decreasing client count (plus the extra client we added)
+ expectedCount := numClients - (i + 1) + 1 // remaining clients + extra test client
+ clientCount := countClientsForKey("K:test:streamhit-multi")
+ assert.Equal(t, expectedCount, clientCount, "Client wasn't properly cleaned up after disconnection")
+ }
+
+ // Clean up extra client
+ resp.Body.Close()
+ time.Sleep(200 * time.Millisecond)
+}
+
+// TestStreamHitConcurrencyStress tests the StreamHit endpoint under high concurrency conditions
+func TestStreamHitConcurrencyStress(t *testing.T) {
+ // Skip in normal testing as this is a long stress test
+ if testing.Short() {
+ t.Skip("Skipping stress test in short mode")
+ }
+
+ gin.SetMode(gin.ReleaseMode) // Reduce logging noise
+ router := setupTestRouter()
+
+ // Create a counter for stress testing
+ createResp := httptest.NewRecorder()
+ createReq, _ := http.NewRequest("POST", "/create/test/streamhit-stress", nil)
+ router.ServeHTTP(createResp, createReq)
+ require.Equal(t, http.StatusCreated, createResp.Code)
+
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ numClients := 20
+ clientDuration := 300 * time.Millisecond
+
+ // Track the current counter value (should increase with each client)
+ var currentValue atomic.Int64
+ currentValue.Store(0)
+
+ // Start with no clients
+ initialCount := countClientsForKey("K:test:streamhit-stress")
+ assert.Equal(t, 0, initialCount)
+
+ // Launch many concurrent clients
+ var wg sync.WaitGroup
+ for i := 0; i < numClients; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+
+ client := &http.Client{}
+
+ // Create request
+ req, err := http.NewRequest("GET", server.URL+"/stream/hit/test/streamhit-stress", nil)
+ if err != nil {
+ t.Logf("Error creating request: %v", err)
+ return
+ }
+ req.Header.Set("Accept", "text/event-stream")
+
+ // Send request
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Logf("Error connecting: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ // Read initial value to verify counter is working
+ scanner := bufio.NewScanner(resp.Body)
+ if scanner.Scan() {
+ line := scanner.Text()
+ if strings.HasPrefix(line, "data: ") {
+ value := extractValueFromEvent(line)
+ if value > 0 {
+ currentValue.Store(int64(value))
+ }
+ }
+ }
+
+ // Keep connection open for the duration
+ time.Sleep(clientDuration)
+ }(i)
+
+ // Stagger client creation slightly
+ time.Sleep(5 * time.Millisecond)
+ }
+
+ wg.Wait()
+
+ // Give extra time for any cleanup
+ time.Sleep(1 * time.Second)
+
+ finalCount := countClientsForKey("K:test:streamhit-stress")
+ assert.Equal(t, 0, finalCount, "Not all clients were cleaned up after stress test")
+
+ counterValue := getCounterValue(t, server.URL, "test", "streamhit-stress")
+ assert.Equal(t, numClients, counterValue, "Counter didn't reach expected value after all clients")
+}