Skip to content

Commit 76aa760

Browse files
committed
LazyListIterable wip
1 parent 08950a8 commit 76aa760

2 files changed

Lines changed: 137 additions & 37 deletions

File tree

library/src/scala/collection/immutable/LazyListIterable.scala

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.lang.{StringBuilder => JStringBuilder}
2222

2323
import scala.annotation.tailrec
2424
import scala.collection.generic.SerializeEnd
25+
import scala.collection.immutable.LazyListIterableBase.InRace
2526
import scala.collection.mutable.{Builder, ReusableBuilder, StringBuilder}
2627
import scala.language.implicitConversions
2728
import scala.runtime.Statics
@@ -264,7 +265,7 @@ import caps.unsafe.untrackedCaptures
264265
*/
265266
@SerialVersionUID(4L)
266267
final class LazyListIterable[+A] private (lazyState: LazyListIterable.EmptyMarker.type | (() => LazyListIterable[A]^) /* EmptyMarker.type | () => LazyListIterable[A] */)
267-
extends Iterable[A]
268+
extends LazyListIterableBase[A](if (lazyState eq LazyListIterable.EmptyMarker) null else lazyState)
268269
with collection.SeqOps[A, LazyListIterable, LazyListIterable[A]]
269270
with IterableFactoryDefaults[A, LazyListIterable]
270271
with Serializable { self: LazyListIterable[A]^ =>
@@ -275,56 +276,86 @@ final class LazyListIterable[+A] private (lazyState: LazyListIterable.EmptyMarke
275276
private def this(head: A, tail: LazyListIterable[A]^) = {
276277
this(LazyListIterable.EmptyMarker)
277278
_head = head
278-
_tail = caps.unsafe.unsafeAssumePure(tail) // SAFETY: we initialize LazyListIterable capturing tail
279+
setRawTail(caps.unsafe.unsafeAssumePure(tail)) // SAFETY: we initialize LazyListIterable capturing tail
279280
}
280281

281-
// used to synchronize lazy state evaluation
282-
// after initialization (`_head ne Uninitialized`)
282+
// `_head` and `_tail` are used to synchronize lazy state evaluation.
283+
//
284+
// initially, `_head` is `Uninitialized`. after initialization, `_head` holds:
283285
// - `null` if this is an empty lazy list
284-
// - `head: A` otherwise (can be `null`, `_tail == null` is used to test emptiness)
286+
// - `head: A` otherwise (can be the `null` value, `_tail == null` is used to test emptiness)
287+
//
288+
// `_tail` (declared in `LazyListBase`) can hold the following values:
289+
// - when `_head eq Uninitialized`
290+
// - `lazyState: () => LazyList[A]`
291+
// - while evaluating `lazyState`: the evaluating `Thread`
292+
// - if multiple threads attempt initialization: an `InRace` instance
293+
// - when `_head ne Uninitialized`
294+
// - `null` if this is an empty lazy list
295+
// - `tail: LazyList[A]` otherwise
285296
@volatile private var _head: Any /* Uninitialized | A */ =
286297
if (lazyState eq EmptyMarker) null else Uninitialized
287298

288-
// when `_head eq Uninitialized`
289-
// - `lazySate: () => LazyListIterable[A]`
290-
// - MidEvaluation while evaluating lazyState
291-
// when `_head ne Uninitialized`
292-
// - `null` if this is an empty lazy list
293-
// - `tail: LazyListIterable[A]` otherwise
294-
private var _tail: AnyRef^{this} | Null /* () => LazyListIterable[A] | MidEvaluation.type | LazyListIterable[A] | Null */ =
295-
if (lazyState eq EmptyMarker) null else lazyState
296-
297299
private def rawHead: Any = _head
298-
private def rawTail: AnyRef^{this} | Null = _tail
299300

300301
@inline private def isEvaluated: Boolean = _head.asInstanceOf[AnyRef] ne Uninitialized
301302

302-
private def initState(): Unit = synchronized {
303-
if (!isEvaluated) {
304-
// if it's already mid-evaluation, we're stuck in an infinite
305-
// self-referential loop (also it's empty)
306-
if (_tail eq MidEvaluation)
307-
throw new RuntimeException(
308-
"LazyListIterable evaluation depends on its own result (self-reference); see docs for more info")
309-
310-
val fun = _tail.asInstanceOf[() ->{this} LazyListIterable[A]^{this}]
311-
_tail = MidEvaluation
312-
val l =
313-
// `fun` returns a LazyListIterable that represents the state (head/tail) of `this`. We call `l.evaluated` to ensure
314-
// `l` is initialized, to prevent races when reading `rawTail` / `rawHead` below.
315-
// Often, lazy lists are created with `newLL(eagerCons(...))` so `l` is already initialized, but `newLL` also
316-
// accepts non-evaluated lazy lists.
317-
try fun().evaluated
318-
// restore `fun` in finally so we can try again later if an exception was thrown (similar to lazy val)
319-
finally _tail = fun
320-
_tail = l.rawTail
321-
_head = l.rawHead
303+
private def initState(): Unit = {
304+
def selfRef(): Nothing =
305+
// if it's already mid-evaluation, we're stuck in an infinite self-referential loop (also it's empty)
306+
throw new RuntimeException(
307+
"LazyList evaluation depends on its own result (self-reference); see docs for more info")
308+
309+
while (!isEvaluated) {
310+
rawTail match {
311+
case t: Thread =>
312+
if (LazyListIterableBase.isCurrentThread(t)) selfRef()
313+
val ir = InRace(t)
314+
if (_tailUpdater.compareAndSet(this, t, ir))
315+
ir.await()
316+
// loop on lost CAS
317+
318+
case ir: InRace =>
319+
if (LazyListIterableBase.isCurrentThread(ir.owner)) selfRef()
320+
ir.await()
321+
322+
case fun: Function0[_] =>
323+
// use the current thread as marker that `fun` is being evaluated.
324+
// this way, there is no allocation in the common case where there's no race.
325+
// if multiple threads attempt to initialize a LazyList, an `InRace` instance is created to coordinate.
326+
if (_tailUpdater.compareAndSet(this, fun, Thread.currentThread)) {
327+
var ex: Throwable | Null = null
328+
// `fun` returns a LazyList that represents the state (head/tail) of `this`. We call `evaluated` to ensure
329+
// the result is initialized, to prevent races when reading `rawTail` / `rawHead` below.
330+
// Often, lazy lists are created with `newLL(eagerCons(...))` so `l` is already initialized, but `newLL`
331+
// also accepts non-evaluated lazy lists.
332+
val l = try fun.asInstanceOf[() ->{this} LazyListIterable[A]^{this}].apply().evaluated catch {
333+
case t: Throwable =>
334+
ex = t
335+
null
336+
}
337+
// update `_tail` before `_head`, because `_head` is used to test `isEvaluated`
338+
val newTail = if (ex == null) l.nn.rawTail else fun
339+
val sentinel = _tailUpdater.getAndSet(this, newTail)
340+
if (ex == null) _head = l.nn.rawHead
341+
sentinel match {
342+
case ir: InRace => ir.countDown()
343+
case _ =>
344+
}
345+
if (ex != null) throw ex.nn
346+
}
347+
// loop on lost CAS
348+
349+
case _ =>
350+
// loop when _tail is a LazyList but _head is still `Uninitialized`
351+
// could call `Thread.onSpinWait()` on JDK 9+
352+
}
322353
}
323354
}
324355

325356
@tailrec private def evaluated: LazyListIterable[A]^{this} =
326357
if (isEvaluated) {
327-
if (_tail == null) Empty
358+
if (rawTail == null) Empty
328359
else this
329360
} else {
330361
initState()
@@ -1183,11 +1214,13 @@ object LazyListIterable extends IterableFactory[LazyListIterable] {
11831214
// def kount(): Unit = k += 1
11841215

11851216
private object Uninitialized extends Serializable
1186-
private object MidEvaluation
11871217
private object EmptyMarker
11881218

11891219
private val Empty: LazyListIterable[Nothing] = new LazyListIterable(EmptyMarker)
11901220

1221+
// lazy val to break cycle (Predef -> scala.package -> val LazyList -> makeTailUpdater -> Predef.classOf)
1222+
private lazy val _tailUpdater: LazyListIterableBase.TailUpdater = Empty.makeTailUpdater
1223+
11911224
/** Creates a new LazyListIterable.
11921225
*
11931226
* @tparam A the element type of the lazy list
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Scala (https://www.scala-lang.org)
3+
*
4+
* Copyright EPFL and Lightbend, Inc. dba Akka
5+
*
6+
* Licensed under Apache License 2.0
7+
* (http://www.apache.org/licenses/LICENSE-2.0).
8+
*
9+
* See the NOTICE file distributed with this work for
10+
* additional information regarding copyright ownership.
11+
*/
12+
13+
package scala.collection.immutable
14+
15+
import scala.language.`2.13`
16+
import language.experimental.captureChecking
17+
18+
import java.util.concurrent.CountDownLatch
19+
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater
20+
21+
/**
22+
* Base class for [[LazyList]] to split out code that uses concurrency utilities that are not available
23+
* on Scala.js. This way, Scala.js does not need to override all of LazyList.
24+
*
25+
* This class cannot be a trait because `AtomicReferenceFieldUpdater.newUpdater` checks if the caller
26+
* class has access to the corresponding field. So it needs to be called in the class where the field is
27+
* declared (fields are always private in Scala).
28+
*/
29+
abstract class LazyListIterableBase[+A] private[immutable] (initialTail: AnyRef | Null) extends Iterable[A] with Serializable {
30+
/** See [[LazyList._head]] for the possible states of this field. */
31+
@volatile private var _tail: AnyRef^{this} | Null /* () => LazyList[A] | Thread | InRace | LazyList[A] | Null */ = initialTail
32+
33+
private[immutable] def rawTail: AnyRef^{this} | Null = _tail
34+
35+
private[immutable] def setRawTail(value: AnyRef): Unit = _tail = value
36+
37+
@noinline private[immutable] def makeTailUpdater: LazyListIterableBase.TailUpdater =
38+
new LazyListIterableBase.TailUpdater(AtomicReferenceFieldUpdater.newUpdater(classOf[LazyListIterableBase[?]], classOf[AnyRef], "_tail"))
39+
}
40+
41+
private[immutable] object LazyListIterableBase {
42+
final class TailUpdater(u: AtomicReferenceFieldUpdater[LazyListIterableBase[?], AnyRef]) {
43+
def compareAndSet(ll: LazyListIterableBase[?], expected: AnyRef, value: AnyRef): Boolean = u.compareAndSet(ll, expected, value)
44+
def getAndSet(ll: LazyListIterableBase[?], value: AnyRef | Null): AnyRef | Null = u.getAndSet(ll, value)
45+
}
46+
47+
// this utility is constant `true` on Scala.js -> enables DCE in LazyList
48+
def isCurrentThread(t: Thread): Boolean = t eq Thread.currentThread
49+
// also for Scala.js
50+
def InRace(t: Thread): InRace = new InRace(t)
51+
52+
final class InRace private[LazyListIterableBase] (val owner: Thread) {
53+
private val done: CountDownLatch = new CountDownLatch(1)
54+
55+
def await(): Unit = {
56+
var interrupted = false
57+
while (done.getCount > 0) {
58+
try done.await() catch {
59+
case _: InterruptedException => interrupted = true
60+
}
61+
}
62+
if (interrupted) Thread.currentThread().interrupt()
63+
}
64+
65+
def countDown(): Unit = done.countDown()
66+
}
67+
}

0 commit comments

Comments
 (0)