Skip to content

Commit f0f4f69

Browse files
authored
Merge pull request scala/scala#9174 from NthPortal/topic/mutation-tracking-iterators/PR
[bug#12009] Make ListBuffer's iterator fail-fast
2 parents 3884ace + 0dc238a commit f0f4f69

File tree

4 files changed

+172
-33
lines changed

4 files changed

+172
-33
lines changed

library/src/scala/collection/mutable/Growable.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ package scala
1414
package collection
1515
package mutable
1616

17-
import scala.collection.IterableOnce
17+
import scala.annotation.nowarn
1818

1919
/** This trait forms part of collections that can be augmented
2020
* using a `+=` operator and that can be cleared of all elements using
@@ -56,10 +56,14 @@ trait Growable[-A] extends Clearable {
5656
* @param xs the IterableOnce producing the elements to $add.
5757
* @return the $coll itself.
5858
*/
59+
@nowarn("msg=will most likely never compare equal")
5960
def addAll(xs: IterableOnce[A]): this.type = {
60-
val it = xs.iterator
61-
while (it.hasNext) {
62-
addOne(it.next())
61+
if (xs.asInstanceOf[AnyRef] eq this) addAll(Buffer.from(xs)) // avoid mutating under our own iterator
62+
else {
63+
val it = xs.iterator
64+
while (it.hasNext) {
65+
addOne(it.next())
66+
}
6367
}
6468
this
6569
}

library/src/scala/collection/mutable/ListBuffer.scala

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@ import scala.runtime.Statics.releaseFence
3535
* @define mayNotTerminateInf
3636
* @define willNotTerminateInf
3737
*/
38+
@SerialVersionUID(-8428291952499836345L)
3839
class ListBuffer[A]
3940
extends AbstractBuffer[A]
4041
with SeqOps[A, ListBuffer, ListBuffer[A]]
4142
with StrictOptimizedSeqOps[A, ListBuffer, ListBuffer[A]]
4243
with ReusableBuilder[A, immutable.List[A]]
4344
with IterableFactoryDefaults[A, ListBuffer]
4445
with DefaultSerializable {
46+
@transient private[this] var mutationCount: Int = 0
4547

4648
private var first: List[A] = Nil
4749
private var last0: ::[A] = null
@@ -50,7 +52,7 @@ class ListBuffer[A]
5052

5153
private type Predecessor[A0] = ::[A0] /*| Null*/
5254

53-
def iterator = first.iterator
55+
def iterator: Iterator[A] = new MutationTracker.CheckedIterator(first.iterator, mutationCount)
5456

5557
override def iterableFactory: SeqFactory[ListBuffer] = ListBuffer
5658

@@ -69,7 +71,12 @@ class ListBuffer[A]
6971
aliased = false
7072
}
7173

72-
private def ensureUnaliased() = if (aliased) copyElems()
74+
// we only call this before mutating things, so it's
75+
// a good place to track mutations for the iterator
76+
private def ensureUnaliased(): Unit = {
77+
mutationCount += 1
78+
if (aliased) copyElems()
79+
}
7380

7481
// Avoids copying where possible.
7582
override def toList: List[A] = {
@@ -97,6 +104,7 @@ class ListBuffer[A]
97104
}
98105

99106
def clear(): Unit = {
107+
mutationCount += 1
100108
first = Nil
101109
len = 0
102110
last0 = null
@@ -114,18 +122,28 @@ class ListBuffer[A]
114122

115123
// Overridden for performance
116124
override final def addAll(xs: IterableOnce[A]): this.type = {
117-
val it = xs.iterator
118-
if (it.hasNext) {
119-
ensureUnaliased()
120-
val last1 = new ::[A](it.next(), Nil)
121-
if (len == 0) first = last1 else last0.next = last1
122-
last0 = last1
123-
len += 1
124-
while (it.hasNext) {
125+
if (xs.asInstanceOf[AnyRef] eq this) { // avoid mutating under our own iterator
126+
if (len > 0) {
127+
ensureUnaliased()
128+
val copy = ListBuffer.from(this)
129+
last0.next = copy.first
130+
last0 = copy.last0
131+
len *= 2
132+
}
133+
} else {
134+
val it = xs.iterator
135+
if (it.hasNext) {
136+
ensureUnaliased()
125137
val last1 = new ::[A](it.next(), Nil)
126-
last0.next = last1
138+
if (len == 0) first = last1 else last0.next = last1
127139
last0 = last1
128140
len += 1
141+
while (it.hasNext) {
142+
val last1 = new ::[A](it.next(), Nil)
143+
last0.next = last1
144+
last0 = last1
145+
len += 1
146+
}
129147
}
130148
}
131149
this
@@ -230,13 +248,29 @@ class ListBuffer[A]
230248
}
231249

232250
def insertAll(idx: Int, elems: IterableOnce[A]): Unit = {
233-
ensureUnaliased()
234-
val it = elems.iterator
235-
if (it.hasNext) {
236-
ensureUnaliased()
237-
if (idx < 0 || idx > len) throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${len-1})")
238-
if (idx == len) ++=(elems)
239-
else insertAfter(locate(idx), it)
251+
if (idx < 0 || idx > len) throw new IndexOutOfBoundsException(s"$idx is out of bounds (min 0, max ${len-1})")
252+
elems match {
253+
case elems: AnyRef if elems eq this => // avoid mutating under our own iterator
254+
if (len > 0) {
255+
val copy = ListBuffer.from(this)
256+
if (idx == 0 || idx == len) { // prepend/append
257+
last0.next = copy.first
258+
last0 = copy.last0
259+
} else {
260+
val prev = locate(idx) // cannot be `null` because other condition catches that
261+
val follow = prev.next
262+
prev.next = copy.first
263+
copy.last0.next = follow
264+
}
265+
len *= 2
266+
}
267+
case elems =>
268+
val it = elems.iterator
269+
if (it.hasNext) {
270+
ensureUnaliased()
271+
if (idx == len) ++=(elems)
272+
else insertAfter(locate(idx), it)
273+
}
240274
}
241275
}
242276

@@ -275,15 +309,17 @@ class ListBuffer[A]
275309
}
276310

277311
def mapInPlace(f: A => A): this.type = {
278-
ensureUnaliased()
312+
mutationCount += 1
279313
val buf = new ListBuffer[A]
280314
for (elem <- this) buf += f(elem)
281315
first = buf.first
282316
last0 = buf.last0
317+
aliased = false // we just assigned from a new instance
283318
this
284319
}
285320

286321
def flatMapInPlace(f: A => IterableOnce[A]): this.type = {
322+
mutationCount += 1
287323
var src = first
288324
var dst: List[A] = null
289325
last0 = null
@@ -299,6 +335,7 @@ class ListBuffer[A]
299335
src = src.tail
300336
}
301337
first = if(dst eq null) Nil else dst
338+
aliased = false // we just rebuilt a fresh, unaliased instance
302339
this
303340
}
304341

@@ -322,12 +359,24 @@ class ListBuffer[A]
322359
}
323360

324361
def patchInPlace(from: Int, patch: collection.IterableOnce[A], replaced: Int): this.type = {
325-
val i = math.min(math.max(from, 0), length)
326-
val n = math.min(math.max(replaced, 0), length)
327-
ensureUnaliased()
328-
val p = locate(i)
329-
removeAfter(p, math.min(n, len - i))
330-
insertAfter(p, patch.iterator)
362+
val _len = len
363+
val _from = math.max(from, 0) // normalized
364+
val _replaced = math.max(replaced, 0) // normalized
365+
val it = patch.iterator
366+
367+
val nonEmptyPatch = it.hasNext
368+
val nonEmptyReplace = (_from < _len) && (_replaced > 0)
369+
370+
// don't want to add a mutation or check aliasing (potentially expensive)
371+
// if there's no patching to do
372+
if (nonEmptyPatch || nonEmptyReplace) {
373+
ensureUnaliased()
374+
val i = math.min(_from, _len)
375+
val n = math.min(_replaced, _len)
376+
val p = locate(i)
377+
removeAfter(p, math.min(n, _len - i))
378+
insertAfter(p, it)
379+
}
331380
this
332381
}
333382

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Scala (https://www.scala-lang.org)
3+
*
4+
* Copyright EPFL and Lightbend, Inc.
5+
*
6+
* Licensed under Apache License 2.0
7+
* (http://www.apache.org/licenses/LICENSE-2.0).
8+
*
9+
* See the NOTICE file distributed with this work for
10+
* additional information regarding copyright ownership.
11+
*/
12+
13+
package scala
14+
package collection
15+
package mutable
16+
17+
import java.util.ConcurrentModificationException
18+
19+
/**
20+
* Utilities to check that mutations to a client that tracks
21+
* its mutations have not occurred since a given point.
22+
* [[Iterator `Iterator`]]s that perform this check automatically
23+
* during iteration can be created by wrapping an `Iterator`
24+
* in a [[MutationTracker.CheckedIterator `CheckedIterator`]],
25+
* or by manually using the [[MutationTracker.checkMutations() `checkMutations`]]
26+
* and [[MutationTracker.checkMutationsForIteration() `checkMutationsForIteration`]]
27+
* methods.
28+
*/
29+
private object MutationTracker {
30+
31+
/**
32+
* Checks whether or not the actual mutation count differs from
33+
* the expected one, throwing an exception, if it does.
34+
*
35+
* @param expectedCount the expected mutation count
36+
* @param actualCount the actual mutation count
37+
* @param message the exception message in case of mutations
38+
* @throws ConcurrentModificationException if the expected and actual
39+
* mutation counts differ
40+
*/
41+
@throws[ConcurrentModificationException]
42+
def checkMutations(expectedCount: Int, actualCount: Int, message: String): Unit = {
43+
if (actualCount != expectedCount) throw new ConcurrentModificationException(message)
44+
}
45+
46+
/**
47+
* Checks whether or not the actual mutation count differs from
48+
* the expected one, throwing an exception, if it does. This method
49+
* produces an exception message saying that it was called because a
50+
* backing collection was mutated during iteration.
51+
*
52+
* @param expectedCount the expected mutation count
53+
* @param actualCount the actual mutation count
54+
* @throws ConcurrentModificationException if the expected and actual
55+
* mutation counts differ
56+
*/
57+
@throws[ConcurrentModificationException]
58+
@inline def checkMutationsForIteration(expectedCount: Int, actualCount: Int): Unit =
59+
checkMutations(expectedCount, actualCount, "mutation occurred during iteration")
60+
61+
/**
62+
* An iterator wrapper that checks if the underlying collection has
63+
* been mutated.
64+
*
65+
* @param underlying the underlying iterator
66+
* @param mutationCount a by-name provider of the current mutation count
67+
* @tparam A the type of the iterator's elements
68+
*/
69+
final class CheckedIterator[A](underlying: Iterator[A], mutationCount: => Int) extends AbstractIterator[A] {
70+
private[this] val expectedCount = mutationCount
71+
72+
def hasNext: Boolean = {
73+
checkMutationsForIteration(expectedCount, mutationCount)
74+
underlying.hasNext
75+
}
76+
def next(): A = underlying.next()
77+
}
78+
}

library/src/scala/collection/mutable/Shrinkable.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
package scala
1414
package collection.mutable
1515

16-
import scala.annotation.tailrec
16+
import scala.annotation.{nowarn, tailrec}
1717

1818
/** This trait forms part of collections that can be reduced
1919
* using a `-=` operator.
@@ -52,16 +52,24 @@ trait Shrinkable[-A] {
5252
* @param xs the iterator producing the elements to remove.
5353
* @return the $coll itself
5454
*/
55+
@nowarn("msg=will most likely never compare equal")
5556
def subtractAll(xs: collection.IterableOnce[A]): this.type = {
5657
@tailrec def loop(xs: collection.LinearSeq[A]): Unit = {
5758
if (xs.nonEmpty) {
5859
subtractOne(xs.head)
5960
loop(xs.tail)
6061
}
6162
}
62-
xs match {
63-
case xs: collection.LinearSeq[A] => loop(xs)
64-
case xs => xs.iterator.foreach(subtractOne)
63+
if (xs.asInstanceOf[AnyRef] eq this) { // avoid mutating under our own iterator
64+
xs match {
65+
case xs: Clearable => xs.clear()
66+
case xs => subtractAll(Buffer.from(xs))
67+
}
68+
} else {
69+
xs match {
70+
case xs: collection.LinearSeq[A] => loop(xs)
71+
case xs => xs.iterator.foreach(subtractOne)
72+
}
6573
}
6674
this
6775
}

0 commit comments

Comments
 (0)