Skip to content

Commit ae024bb

Browse files
authored
Fix race conditions in pubsub (#966)
Rewrote the logic to: - Maintain separate AtomicCell state maps for channels and patterns. - Create a new subscription, topic and Dispatcher first time someone subscribes to channel or pattern. - In case someone subscribes to the same thing again the subscriber count is increased. - Subscriber count is decreased when the Stream is terminated. - unsubscribe finishes all Streams. - Cleanup is performed when the last stream terminates. Also changed the return type of publish to match the type from Lettuce.
1 parent 63532d6 commit ae024bb

File tree

13 files changed

+388
-146
lines changed

13 files changed

+388
-146
lines changed

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSub.scala

+3-4
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@ import cats.syntax.all._
2323
import dev.profunktor.redis4cats.connection.RedisClient
2424
import dev.profunktor.redis4cats.data._
2525
import dev.profunktor.redis4cats.effect._
26-
import dev.profunktor.redis4cats.pubsub.internals.{ LivePubSubCommands, Publisher, Subscriber }
26+
import dev.profunktor.redis4cats.pubsub.internals.{ LivePubSubCommands, PubSubState, Publisher, Subscriber }
2727
import fs2.Stream
28-
import dev.profunktor.redis4cats.pubsub.internals.PubSubState
2928
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection
3029

3130
object PubSub {
@@ -58,7 +57,7 @@ object PubSub {
5857
val (acquire, release) = acquireAndRelease[F, K, V](client, codec)
5958
// One exclusive connection for subscriptions and another connection for publishing / stats
6059
for {
61-
state <- Resource.eval(Ref.of[F, PubSubState[F, K, V]](PubSubState(Map.empty, Map.empty)))
60+
state <- Resource.eval(PubSubState.make[F, K, V])
6261
sConn <- Resource.make(acquire)(release)
6362
pConn <- Resource.make(acquire)(release)
6463
} yield new LivePubSubCommands[F, K, V](state, sConn, pConn)
@@ -88,7 +87,7 @@ object PubSub {
8887
): Resource[F, SubscribeCommands[F, Stream[F, *], K, V]] = {
8988
val (acquire, release) = acquireAndRelease[F, K, V](client, codec)
9089
for {
91-
state <- Resource.eval(Ref.of[F, PubSubState[F, K, V]](PubSubState(Map.empty, Map.empty)))
90+
state <- Resource.eval(PubSubState.make[F, K, V])
9291
conn <- Resource.make(acquire)(release)
9392
} yield new Subscriber(state, conn)
9493
}

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/PubSubCommands.scala

+29-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ trait PubSubStats[F[_], K] {
3737
* @tparam V the value type
3838
*/
3939
trait PublishCommands[F[_], S[_], K, V] extends PubSubStats[F, K] {
40-
def publish(channel: RedisChannel[K]): S[V] => S[Unit]
41-
def publish(channel: RedisChannel[K], value: V): F[Unit]
40+
41+
/** @return The number of clients that received the message. */
42+
def publish(channel: RedisChannel[K]): S[V] => S[Long]
43+
44+
/** @return The number of clients that received the message. */
45+
def publish(channel: RedisChannel[K], value: V): F[Long]
4246
}
4347

4448
/**
@@ -51,17 +55,40 @@ trait SubscribeCommands[F[_], S[_], K, V] {
5155

5256
/**
5357
* Subscribes to a channel.
58+
*
59+
* @note If you invoke `subscribe` multiple times for the same channel, we will not call 'SUBSCRIBE' in Redis multiple
60+
* times but instead will return a stream that will use the existing subscription to that channel. The underlying
61+
* subscription is cleaned up when all the streams terminate or when `unsubscribe` is invoked.
5462
*/
5563
def subscribe(channel: RedisChannel[K]): S[V]
5664

65+
/** Terminates all streams that are subscribed to the channel. */
5766
def unsubscribe(channel: RedisChannel[K]): F[Unit]
5867

5968
/**
6069
* Subscribes to a pattern.
70+
*
71+
* @note If you invoke `subscribe` multiple times for the same pattern, we will not call 'SUBSCRIBE' in Redis multiple
72+
* times but instead will return a stream that will use the existing subscription to that pattern. The underlying
73+
* subscription is cleaned up when all the streams terminate or when `unsubscribe` is invoked.
6174
*/
6275
def psubscribe(channel: RedisPattern[K]): S[RedisPatternEvent[K, V]]
6376

77+
/** Terminates all streams that are subscribed to the pattern. */
6478
def punsubscribe(channel: RedisPattern[K]): F[Unit]
79+
80+
/** Returns the channel subscriptions that the library keeps of.
81+
*
82+
* @return how many streams are subscribed to each channel.
83+
* @see [[SubscribeCommands.subscribe]] for more information.
84+
* */
85+
def internalChannelSubscriptions: F[Map[RedisChannel[K], Long]]
86+
87+
/** Returns the pattern subscriptions that the library keeps of.
88+
*
89+
* @return how many streams are subscribed to each pattern.
90+
* @see [[SubscribeCommands.psubscribe]] for more information. */
91+
def internalPatternSubscriptions: F[Map[RedisPattern[K], Long]]
6592
}
6693

6794
/**

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubCommands.scala

+10-7
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import fs2.Stream
2929
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection
3030

3131
private[pubsub] class LivePubSubCommands[F[_]: Async: Log, K, V](
32-
state: Ref[F, PubSubState[F, K, V]],
32+
state: PubSubState[F, K, V],
3333
subConnection: StatefulRedisPubSubConnection[K, V],
3434
pubConnection: StatefulRedisPubSubConnection[K, V]
3535
) extends PubSubCommands[F, Stream[F, *], K, V] {
@@ -50,13 +50,17 @@ private[pubsub] class LivePubSubCommands[F[_]: Async: Log, K, V](
5050
override def punsubscribe(pattern: RedisPattern[K]): F[Unit] =
5151
subCommands.punsubscribe(pattern)
5252

53-
override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Unit] =
53+
override def internalChannelSubscriptions: F[Map[RedisChannel[K], Long]] =
54+
subCommands.internalChannelSubscriptions
55+
56+
override def internalPatternSubscriptions: F[Map[RedisPattern[K], Long]] =
57+
subCommands.internalPatternSubscriptions
58+
59+
override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Long] =
5460
_.evalMap(publish(channel, _))
5561

56-
override def publish(channel: RedisChannel[K], message: V): F[Unit] = {
57-
val resource = Resource.eval(state.get) >>= PubSubInternals.channel[F, K, V](state, subConnection).apply(channel)
58-
resource.use(_ => FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).void)
59-
}
62+
override def publish(channel: RedisChannel[K], message: V): F[Long] =
63+
FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).map(l => l: Long)
6064

6165
override def numPat: F[Long] =
6266
pubSubStats.numPat
@@ -78,5 +82,4 @@ private[pubsub] class LivePubSubCommands[F[_]: Async: Log, K, V](
7882

7983
override def shardNumSub(channels: List[RedisChannel[K]]): F[List[Subscription[K]]] =
8084
pubSubStats.shardNumSub(channels)
81-
8285
}

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/LivePubSubStats.scala

-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ private[pubsub] class LivePubSubStats[F[_]: FlatMap: FutureLift, K, V](
6363
FutureLift[F]
6464
.lift(pubConnection.async().pubsubShardNumsub(channels.map(_.underlying): _*))
6565
.map(toSubscription[K])
66-
6766
}
6867
object LivePubSubStats {
6968
private def toSubscription[K](map: ju.Map[K, JLong]): List[Subscription[K]] =

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubInternals.scala

+8-59
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,26 @@
1717
package dev.profunktor.redis4cats.pubsub.internals
1818

1919
import scala.util.control.NoStackTrace
20-
21-
import cats.effect.kernel.{ Async, Ref, Resource, Sync }
22-
import cats.effect.std.Dispatcher
23-
import cats.syntax.all._
20+
import cats.effect.std.{ Dispatcher }
2421
import dev.profunktor.redis4cats.data.RedisChannel
2522
import dev.profunktor.redis4cats.data.RedisPattern
2623
import dev.profunktor.redis4cats.data.RedisPatternEvent
27-
import dev.profunktor.redis4cats.effect.Log
28-
import fs2.concurrent.Topic
29-
import io.lettuce.core.pubsub.{ RedisPubSubListener, StatefulRedisPubSubConnection }
24+
import io.lettuce.core.pubsub.{ RedisPubSubListener }
3025
import io.lettuce.core.pubsub.RedisPubSubAdapter
3126

3227
object PubSubInternals {
3328
case class DispatcherAlreadyShutdown() extends NoStackTrace
3429

35-
private[redis4cats] def channelListener[F[_]: Async, K, V](
30+
private[redis4cats] def channelListener[F[_], K, V](
3631
channel: RedisChannel[K],
37-
topic: Topic[F, Option[V]],
32+
publish: V => F[Unit],
3833
dispatcher: Dispatcher[F]
3934
): RedisPubSubListener[K, V] =
4035
new RedisPubSubAdapter[K, V] {
4136
override def message(ch: K, msg: V): Unit =
4237
if (ch == channel.underlying) {
4338
try {
44-
dispatcher.unsafeRunSync(topic.publish1(Option(msg)).void)
39+
dispatcher.unsafeRunSync(publish(msg))
4540
} catch {
4641
case _: IllegalStateException => throw DispatcherAlreadyShutdown()
4742
}
@@ -50,65 +45,19 @@ object PubSubInternals {
5045
// Do not uncomment this, as if you will do this the channel listener will get a message twice
5146
// override def message(pattern: K, channel: K, message: V): Unit = {}
5247
}
53-
private[redis4cats] def patternListener[F[_]: Async, K, V](
48+
private[redis4cats] def patternListener[F[_], K, V](
5449
redisPattern: RedisPattern[K],
55-
topic: Topic[F, Option[RedisPatternEvent[K, V]]],
50+
publish: RedisPatternEvent[K, V] => F[Unit],
5651
dispatcher: Dispatcher[F]
5752
): RedisPubSubListener[K, V] =
5853
new RedisPubSubAdapter[K, V] {
5954
override def message(pattern: K, channel: K, message: V): Unit =
6055
if (pattern == redisPattern.underlying) {
6156
try {
62-
dispatcher.unsafeRunSync(topic.publish1(Option(RedisPatternEvent(pattern, channel, message))).void)
57+
dispatcher.unsafeRunSync(publish(RedisPatternEvent(pattern, channel, message)))
6358
} catch {
6459
case _: IllegalStateException => throw DispatcherAlreadyShutdown()
6560
}
6661
}
6762
}
68-
69-
private[redis4cats] def channel[F[_]: Async: Log, K, V](
70-
state: Ref[F, PubSubState[F, K, V]],
71-
subConnection: StatefulRedisPubSubConnection[K, V]
72-
): GetOrCreateTopicListener[F, K, V] = { channel => st =>
73-
st.channels
74-
.get(channel.underlying)
75-
.fold {
76-
for {
77-
dispatcher <- Dispatcher.parallel[F]
78-
topic <- Resource.eval(Topic[F, Option[V]])
79-
_ <- Resource.eval(Log[F].info(s"Creating listener for channel: $channel"))
80-
listener = channelListener(channel, topic, dispatcher)
81-
_ <- Resource.make {
82-
Sync[F].delay(subConnection.addListener(listener)) *>
83-
state.update(s => s.copy(channels = s.channels.updated(channel.underlying, topic)))
84-
} { _ =>
85-
Sync[F].delay(subConnection.removeListener(listener)) *>
86-
state.update(s => s.copy(channels = s.channels - channel.underlying))
87-
}
88-
} yield topic
89-
}(Resource.pure)
90-
}
91-
92-
private[redis4cats] def pattern[F[_]: Async: Log, K, V](
93-
state: Ref[F, PubSubState[F, K, V]],
94-
subConnection: StatefulRedisPubSubConnection[K, V]
95-
): GetOrCreatePatternListener[F, K, V] = { channel => st =>
96-
st.patterns
97-
.get(channel.underlying)
98-
.fold {
99-
for {
100-
dispatcher <- Dispatcher.parallel[F]
101-
topic <- Resource.eval(Topic[F, Option[RedisPatternEvent[K, V]]])
102-
_ <- Resource.eval(Log[F].info(s"Creating listener for pattern: $channel"))
103-
listener = patternListener(channel, topic, dispatcher)
104-
_ <- Resource.make {
105-
Sync[F].delay(subConnection.addListener(listener)) *>
106-
state.update(s => s.copy(patterns = s.patterns.updated(channel.underlying, topic)))
107-
} { _ =>
108-
Sync[F].delay(subConnection.removeListener(listener)) *>
109-
state.update(s => s.copy(patterns = s.patterns - channel.underlying))
110-
}
111-
} yield topic
112-
}(Resource.pure)
113-
}
11463
}

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/PubSubState.scala

+16-5
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,21 @@
1616

1717
package dev.profunktor.redis4cats.pubsub.internals
1818

19-
import dev.profunktor.redis4cats.data.RedisPatternEvent
20-
import fs2.concurrent.Topic
19+
import cats.syntax.all._
20+
import cats.effect.kernel.Concurrent
21+
import cats.effect.std.AtomicCell
22+
import dev.profunktor.redis4cats.data.{ RedisChannel, RedisPattern, RedisPatternEvent }
2123

22-
final case class PubSubState[F[_], K, V](
23-
channels: Map[K, Topic[F, Option[V]]],
24-
patterns: Map[K, Topic[F, Option[RedisPatternEvent[K, V]]]]
24+
/** We use `AtomicCell` instead of `Ref` because we need locking while side-effecting. */
25+
case class PubSubState[F[_], K, V](
26+
channelSubs: AtomicCell[F, Map[RedisChannel[K], Redis4CatsSubscription[F, V]]],
27+
patternSubs: AtomicCell[F, Map[RedisPattern[K], Redis4CatsSubscription[F, RedisPatternEvent[K, V]]]]
2528
)
29+
object PubSubState {
30+
def make[F[_]: Concurrent, K, V]: F[PubSubState[F, K, V]] =
31+
for {
32+
channelSubs <- AtomicCell[F].of(Map.empty[RedisChannel[K], Redis4CatsSubscription[F, V]])
33+
patternSubs <- AtomicCell[F].of(Map.empty[RedisPattern[K], Redis4CatsSubscription[F, RedisPatternEvent[K, V]]])
34+
} yield apply(channelSubs, patternSubs)
35+
36+
}

modules/streams/src/main/scala/dev/profunktor/redis4cats/pubsub/internals/Publisher.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ private[pubsub] class Publisher[F[_]: FlatMap: FutureLift, K, V](
3232

3333
private[redis4cats] val pubSubStats: PubSubStats[F, K] = new LivePubSubStats(pubConnection)
3434

35-
override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Unit] =
35+
override def publish(channel: RedisChannel[K]): Stream[F, V] => Stream[F, Long] =
3636
_.evalMap(publish(channel, _))
3737

38-
override def publish(channel: RedisChannel[K], message: V): F[Unit] =
39-
FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).void
38+
override def publish(channel: RedisChannel[K], message: V): F[Long] =
39+
FutureLift[F].lift(pubConnection.async().publish(channel.underlying, message)).map(l => l: Long)
4040

4141
override def pubSubChannels: F[List[RedisChannel[K]]] =
4242
pubSubStats.pubSubChannels
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright 2018-2021 ProfunKtor
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package dev.profunktor.redis4cats.pubsub.internals
18+
19+
import cats.Applicative
20+
import fs2.concurrent.Topic
21+
22+
/**
23+
* Stores an ongoing subscription.
24+
*
25+
* @param topic single-publisher, multiple-subscribers. The same topic is reused if `subscribe` is invoked more than
26+
* once. The subscribers' streams are terminated when `None` is published.
27+
* @param subscribers subscriber count, when `subscribers` reaches 0 `cleanup` is called and `None` is published
28+
* to the topic.
29+
*/
30+
final private[redis4cats] case class Redis4CatsSubscription[F[_], V](
31+
topic: Topic[F, Option[V]],
32+
subscribers: Long,
33+
cleanup: F[Unit]
34+
) {
35+
assert(subscribers > 0, s"subscribers must be > 0, was $subscribers")
36+
37+
def addSubscriber: Redis4CatsSubscription[F, V] = copy(subscribers = subscribers + 1)
38+
def removeSubscriber: Redis4CatsSubscription[F, V] = copy(subscribers = subscribers - 1)
39+
def isLastSubscriber: Boolean = subscribers == 1
40+
41+
def stream(onTermination: F[Unit])(implicit F: Applicative[F]): fs2.Stream[F, V] =
42+
topic.subscribe(500).unNoneTerminate.onFinalize(onTermination)
43+
}

0 commit comments

Comments
 (0)