Skip to content

Commit ce147aa

Browse files
authored
Merge pull request scala/scala#8073 from lihaoyi/fast-update-with
override mutable.HashMap#updateWith and mutable.LinkedHashMap#updateWith for performance
2 parents d99ad20 + 99104cf commit ce147aa

File tree

3 files changed

+95
-10
lines changed

3 files changed

+95
-10
lines changed

library/src/scala/collection/mutable/HashMap.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,57 @@ class HashMap[K, V](initialCapacity: Int, loadFactor: Double)
106106
}
107107
}
108108

109+
// Override updateWith for performance, so we can do the update while hashing
110+
// the input key only once and performing one lookup into the hash table
111+
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
112+
if (getClass != classOf[HashMap[_, _]]) {
113+
// subclasses of HashMap might customise `get` ...
114+
super.updateWith(key)(remappingFunction)
115+
} else {
116+
val hash = computeHash(key)
117+
val indexedHash = index(hash)
118+
119+
var foundNode: Node[K, V] = null
120+
var previousNode: Node[K, V] = null
121+
table(indexedHash) match {
122+
case null =>
123+
case nd =>
124+
@tailrec
125+
def findNode(prev: Node[K, V], nd: Node[K, V], k: K, h: Int): Unit = {
126+
if (h == nd.hash && k == nd.key) {
127+
previousNode = prev
128+
foundNode = nd
129+
}
130+
else if ((nd.next eq null) || (nd.hash > h)) ()
131+
else findNode(nd, nd.next, k, h)
132+
}
133+
134+
findNode(null, nd, key, hash)
135+
}
136+
137+
val previousValue = foundNode match {
138+
case null => None
139+
case nd => Some(nd.value)
140+
}
141+
142+
val nextValue = remappingFunction(previousValue)
143+
144+
(previousValue, nextValue) match {
145+
case (None, None) => // do nothing
146+
147+
case (Some(_), None) =>
148+
if (previousNode != null) previousNode.next = foundNode.next
149+
else table(indexedHash) = foundNode.next
150+
contentSize -= 1
151+
152+
case (None, Some(value)) => put0(key, value, false, hash, indexedHash)
153+
154+
case (Some(_), Some(newValue)) => foundNode.value = newValue
155+
}
156+
nextValue
157+
}
158+
}
159+
109160
override def subtractAll(xs: IterableOnce[K]): this.type = {
110161
if (size == 0) {
111162
return this

library/src/scala/collection/mutable/HashTable.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ private[collection] /*abstract class*/ trait HashTable[A, B, Entry >: Null <: Ha
176176
/** Remove entry from table if present.
177177
*/
178178
final def removeEntry(key: A) : Entry = {
179-
val h = index(elemHashCode(key))
179+
removeEntry0(key, index(elemHashCode(key)))
180+
}
181+
/** Remove entry from table if present.
182+
*/
183+
private[collection] final def removeEntry0(key: A, h: Int) : Entry = {
180184
var e = table(h).asInstanceOf[Entry]
181185
if (e != null) {
182186
if (elemEquals(e.key, key)) {

library/src/scala/collection/mutable/LinkedHashMap.scala

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,17 @@ class LinkedHashMap[K, V]
139139
override def remove(key: K): Option[V] = {
140140
val e = table.removeEntry(key)
141141
if (e eq null) None
142-
else {
143-
if (e.earlier eq null) firstEntry = e.later
144-
else e.earlier.later = e.later
145-
if (e.later eq null) lastEntry = e.earlier
146-
else e.later.earlier = e.earlier
147-
e.earlier = null // Null references to prevent nepotism
148-
e.later = null
149-
Some(e.value)
150-
}
142+
else Some(remove0(e))
143+
}
144+
145+
private[this] def remove0(e: Entry): V = {
146+
if (e.earlier eq null) firstEntry = e.later
147+
else e.earlier.later = e.later
148+
if (e.later eq null) lastEntry = e.earlier
149+
else e.later.earlier = e.earlier
150+
e.earlier = null // Null references to prevent nepotism
151+
e.later = null
152+
e.value
151153
}
152154

153155
def addOne(kv: (K, V)): this.type = { put(kv._1, kv._2); this }
@@ -176,6 +178,34 @@ class LinkedHashMap[K, V]
176178
else Iterator.empty.next()
177179
}
178180

181+
// Override updateWith for performance, so we can do the update while hashing
182+
// the input key only once and performing one lookup into the hash table
183+
override def updateWith(key: K)(remappingFunction: Option[V] => Option[V]): Option[V] = {
184+
val keyIndex = table.index(table.elemHashCode(key))
185+
val entry = table.findEntry0(key, keyIndex)
186+
187+
val previousValue =
188+
if (entry == null) None
189+
else Some(entry.value)
190+
191+
val nextValue = remappingFunction(previousValue)
192+
193+
(previousValue, nextValue) match {
194+
case (None, None) => // do nothing
195+
case (Some(_), None) =>
196+
remove0(entry)
197+
table.removeEntry0(key, keyIndex)
198+
199+
case (None, Some(value)) =>
200+
table.addEntry0(table.createNewEntry(key, value), keyIndex)
201+
202+
case (Some(_), Some(value)) =>
203+
entry.value = value
204+
}
205+
206+
nextValue
207+
}
208+
179209
override def valuesIterator: Iterator[V] = new AbstractIterator[V] {
180210
private[this] var cur = firstEntry
181211
def hasNext = cur ne null

0 commit comments

Comments
 (0)