Skip to content

Commit 842d2de

Browse files
committed
LongHashedRelation off-heap
1 parent acfedac commit 842d2de

File tree

1 file changed

+77
-69
lines changed

1 file changed

+77
-69
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 77 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.io._
2222
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
2323
import com.esotericsoftware.kryo.io.{Input, Output}
2424

25-
import org.apache.spark.{SparkConf, SparkEnv, SparkException, SparkUnsupportedOperationException}
25+
import org.apache.spark.{SparkConf, SparkEnv, SparkUnsupportedOperationException}
2626
import org.apache.spark.internal.config.{BUFFER_PAGESIZE, MEMORY_OFFHEAP_ENABLED}
2727
import org.apache.spark.memory._
2828
import org.apache.spark.sql.catalyst.InternalRow
@@ -32,6 +32,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
3232
import org.apache.spark.sql.types.LongType
3333
import org.apache.spark.unsafe.Platform
3434
import org.apache.spark.unsafe.map.BytesToBytesMap
35+
import org.apache.spark.unsafe.memory.MemoryBlock
3536
import org.apache.spark.util.{KnownSizeEstimation, Utils}
3637

3738
/**
@@ -535,7 +536,7 @@ private[execution] final class LongToUnsafeRowMap(
535536
val mm: TaskMemoryManager,
536537
capacity: Int,
537538
ignoresDuplicatedKey: Boolean = false)
538-
extends MemoryConsumer(mm, MemoryMode.ON_HEAP) with Externalizable with KryoSerializable {
539+
extends MemoryConsumer(mm, mm.getTungstenMemoryMode) with Externalizable with KryoSerializable {
539540

540541
// Whether the keys are stored in dense mode or not.
541542
private var isDense = false
@@ -550,15 +551,15 @@ private[execution] final class LongToUnsafeRowMap(
550551
//
551552
// Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ...
552553
// Dense mode: [offset1 | size1] [offset2 | size2]
553-
private var array: Array[Long] = null
554+
private var array: UnsafeLongArray = null
554555
private var mask: Int = 0
555556

556557
// The page to store all bytes of UnsafeRow and the pointer to next rows.
557558
// [row1][pointer1] [row2][pointer2]
558-
private var page: Array[Long] = null
559+
private var page: MemoryBlock = null
559560

560561
// Current write cursor in the page.
561-
private var cursor: Long = Platform.LONG_ARRAY_OFFSET
562+
private var cursor: Long = -1
562563

563564
// The number of bits for size in address
564565
private val SIZE_BITS = 28
@@ -583,24 +584,15 @@ private[execution] final class LongToUnsafeRowMap(
583584
0)
584585
}
585586

586-
private def ensureAcquireMemory(size: Long): Unit = {
587-
// do not support spilling
588-
val got = acquireMemory(size)
589-
if (got < size) {
590-
freeMemory(got)
591-
throw QueryExecutionErrors.cannotAcquireMemoryToBuildLongHashedRelationError(size, got)
592-
}
593-
}
594-
595587
/**
 * Sets up the key index and the first data page.
 *
 * Nothing is allocated when no memory manager was supplied (`mm == null`).
 */
private def init(): Unit = {
  if (mm != null) {
    require(capacity < 512000000, "Cannot broadcast 512 million or more rows")
    // Smallest power of two that is >= capacity (1 when capacity <= 1).
    val numSlots = if (capacity <= 1) 1 else Integer.highestOneBit(capacity - 1) << 1
    // Two longs per slot.
    val arrayLength = numSlots * 2
    array = new UnsafeLongArray(arrayLength)
    mask = arrayLength - 2
    page = allocatePage(1 << 20) // 1M bytes
    // The write cursor starts at the first byte of the freshly allocated page.
    cursor = page.getBaseOffset
  }
}
606598

@@ -616,7 +608,7 @@ private[execution] final class LongToUnsafeRowMap(
616608
/**
 * Returns the total memory consumption in bytes: 8 bytes for every slot of the key index
 * plus the size reported by the data page.
 */
def getTotalMemoryConsumption: Long = {
  val indexBytes = array.length * 8L
  val pageBytes = page.size()
  indexBytes + pageBytes
}
620612

621613
/**
622614
* Returns the first slot of array that store the keys (sparse mode).
@@ -632,19 +624,19 @@ private[execution] final class LongToUnsafeRowMap(
632624
// Advances two positions from `pos` and wraps the result with `mask`.
private def nextSlot(pos: Int): Int = {
  val advanced = pos + 2
  advanced & mask
}
633625

634626
// Packs an offset and a row size into one address word: the offset is shifted left by
// SIZE_BITS and the size is OR-ed into the low bits.
private[this] def toAddress(offset: Long, size: Int): Long = {
  val shiftedOffset = offset << SIZE_BITS
  shiftedOffset | size
}
637629

638630
// Extracts the offset part of an address word by dropping the low SIZE_BITS bits
// (unsigned shift).
private[this] def toOffset(address: Long): Long = address >>> SIZE_BITS
641633

642634
// Extracts the size part of an address word: the bits selected by SIZE_MASK.
private[this] def toSize(address: Long): Int = {
  val sizeBits = address & SIZE_MASK
  sizeBits.toInt
}
645637

646638
// Points `resultRow` at the row encoded by `address` inside the data page and returns it.
private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
  // The address carries a page-relative offset; add the page base to locate the row.
  val rowStart = page.getBaseOffset + toOffset(address)
  val rowSize = toSize(address)
  resultRow.pointTo(page.getBaseObject, rowStart, rowSize)
  resultRow
}
650642

@@ -681,8 +673,8 @@ private[execution] final class LongToUnsafeRowMap(
681673
override def next(): UnsafeRow = {
  val rowSize = toSize(addr)
  // Absolute position of the current row inside the data page.
  val rowStart = page.getBaseOffset + toOffset(addr)
  resultRow.pointTo(page.getBaseObject, rowStart, rowSize)
  // The 8 bytes right after the row hold the address of the next row for the same key.
  addr = Platform.getLong(page.getBaseObject, rowStart + rowSize)
  resultRow
}
688680
}
@@ -777,12 +769,13 @@ private[execution] final class LongToUnsafeRowMap(
777769

778770
// copy the bytes of UnsafeRow
779771
val offset = cursor
780-
Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
772+
Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page.getBaseObject, cursor,
773+
row.getSizeInBytes)
781774
cursor += row.getSizeInBytes
782-
Platform.putLong(page, cursor, 0)
775+
Platform.putLong(page.getBaseObject, cursor, 0)
783776
cursor += 8
784777
numValues += 1
785-
updateIndex(key, pos, toAddress(offset, row.getSizeInBytes))
778+
updateIndex(key, pos, toAddress(offset - page.getBaseOffset, row.getSizeInBytes))
786779
}
787780

788781
private def findKeyPosition(key: Long): Int = {
@@ -816,35 +809,32 @@ private[execution] final class LongToUnsafeRowMap(
816809
} else {
817810
// there are some values for this key, put the address in the front of them.
818811
val pointer = toOffset(address) + toSize(address)
819-
Platform.putLong(page, pointer, array(pos + 1))
812+
Platform.putLong(page.getBaseObject, page.getBaseOffset + pointer, array(pos + 1))
820813
array(pos + 1) = address
821814
}
822815
}
823816

824817
/**
 * Makes sure the data page can take one more row of `inputRowSize` bytes plus the 8-byte
 * pointer to the next value, replacing the page with a larger one when it cannot.
 *
 * When the page is replaced, the bytes written so far are copied into the new page, the old
 * page is freed and `cursor` is moved so that it addresses the same position in the new page.
 *
 * @param inputRowSize size in bytes of the row about to be appended
 * @throws the error built by `QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError`
 *         when more than 2^30 eight-byte words would be needed
 */
private def grow(inputRowSize: Int): Unit = {
  // Bytes already written, relative to the start of the current page.
  val usedBytes = cursor - page.getBaseOffset
  // There is 8 bytes for the pointer to next value
  val neededNumWords = (usedBytes + 8 + inputRowSize + 7) / 8
  if (neededNumWords > page.size() / 8) {
    if (neededNumWords > (1 << 30)) {
      throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError()
    }
    val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30))
    // Multiply as Long: newNumWords may be as large as 1 << 30, and an Int product with 8
    // would overflow.
    val newPage = allocatePage(newNumWords * 8L)
    Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject,
      newPage.getBaseOffset, usedBytes)
    freePage(page)
    page = newPage
    // `cursor` is an absolute offset (page base offset + bytes used). The new page can have a
    // different base offset than the old one, so rebase the cursor onto the new page;
    // otherwise later writes would still target the page that was just freed.
    cursor = newPage.getBaseOffset + usedBytes
  }
}
841832

842833
private def growArray(): Unit = {
843834
var old_array = array
844835
val n = array.length
845836
numKeys = 0
846-
ensureAcquireMemory(n * 2 * 8L)
847-
array = new Array[Long](n * 2)
837+
array = new UnsafeLongArray(n * 2)
848838
mask = n * 2 - 2
849839
var i = 0
850840
while (i < old_array.length) {
@@ -854,8 +844,8 @@ private[execution] final class LongToUnsafeRowMap(
854844
}
855845
i += 2
856846
}
847+
old_array.free()
857848
old_array = null // release the reference to old array
858-
freeMemory(n * 8L)
859849
}
860850

861851
/**
@@ -866,14 +856,7 @@ private[execution] final class LongToUnsafeRowMap(
866856
// Convert to dense mode if it does not require more memory or could fit within L1 cache
867857
// SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value
868858
if (range >= 0 && (range < array.length || range < 1024)) {
869-
try {
870-
ensureAcquireMemory((range + 1) * 8L)
871-
} catch {
872-
case e: SparkException =>
873-
// there is no enough memory to convert
874-
return
875-
}
876-
val denseArray = new Array[Long]((range + 1).toInt)
859+
val denseArray = new UnsafeLongArray((range + 1).toInt)
877860
var i = 0
878861
while (i < array.length) {
879862
if (array(i + 1) > 0) {
@@ -882,10 +865,9 @@ private[execution] final class LongToUnsafeRowMap(
882865
}
883866
i += 2
884867
}
885-
val old_length = array.length
868+
array.free()
886869
array = denseArray
887870
isDense = true
888-
freeMemory(old_length * 8L)
889871
}
890872
}
891873

@@ -894,25 +876,26 @@ private[execution] final class LongToUnsafeRowMap(
894876
*/
895877
def free(): Unit = {
  // Release the data page, if there is one, and forget the reference.
  Option(page).foreach { currentPage =>
    freePage(currentPage)
    page = null
  }
  // Release the key index, if there is one, and forget the reference.
  Option(array).foreach { currentArray =>
    currentArray.free()
    array = null
  }
}
905887

906-
private def writeLongArray(
888+
private def writeBytes(
907889
writeBuffer: (Array[Byte], Int, Int) => Unit,
908-
arr: Array[Long],
890+
baseObject: Object,
891+
baseOffset: Long,
909892
len: Int): Unit = {
910893
val buffer = new Array[Byte](4 << 10)
911-
var offset: Long = Platform.LONG_ARRAY_OFFSET
912-
val end = len * 8L + Platform.LONG_ARRAY_OFFSET
894+
var offset: Long = baseOffset
895+
val end = len * 8L + offset
913896
while (offset < end) {
914897
val size = Math.min(buffer.length, end - offset)
915-
Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
898+
Platform.copyMemory(baseObject, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
916899
writeBuffer(buffer, 0, size.toInt)
917900
offset += size
918901
}
@@ -929,10 +912,11 @@ private[execution] final class LongToUnsafeRowMap(
929912
writeLong(numValues)
930913

931914
writeLong(array.length)
932-
writeLongArray(writeBuffer, array, array.length)
933-
val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
915+
writeBytes(writeBuffer,
916+
array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, array.length)
917+
val used = ((cursor - page.getBaseOffset) / 8).toInt
934918
writeLong(used)
935-
writeLongArray(writeBuffer, page, used)
919+
writeBytes(writeBuffer, page.getBaseObject, page.getBaseOffset, used)
936920
}
937921

938922
override def writeExternal(output: ObjectOutput): Unit = {
@@ -943,20 +927,20 @@ private[execution] final class LongToUnsafeRowMap(
943927
write(out.writeBoolean, out.writeLong, out.write)
944928
}
945929

946-
private def readLongArray(
930+
// Fills `length` 8-byte words starting at (`baseObject`, `baseOffset`) with bytes obtained
// from `readBuffer`, going through a 4 KiB staging buffer.
private def readData(
    readBuffer: (Array[Byte], Int, Int) => Unit,
    baseObject: Object,
    baseOffset: Long,
    length: Int): Unit = {
  val chunk = new Array[Byte](4 << 10)
  val totalBytes = length * 8L
  var copied = 0L
  while (copied < totalBytes) {
    // Read at most one chunk, or whatever is left for the final piece.
    val n = Math.min(chunk.length, totalBytes - copied)
    readBuffer(chunk, 0, n.toInt)
    Platform.copyMemory(chunk, Platform.BYTE_ARRAY_OFFSET, baseObject, baseOffset + copied, n)
    copied += n
  }
}
961945

962946
private def read(
@@ -971,11 +955,15 @@ private[execution] final class LongToUnsafeRowMap(
971955

972956
val length = readLong().toInt
973957
mask = length - 2
974-
array = readLongArray(readBuffer, length)
958+
array.free()
959+
array = new UnsafeLongArray(length)
960+
readData(readBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, length)
975961
val pageLength = readLong().toInt
976-
page = readLongArray(readBuffer, pageLength)
962+
freePage(page)
963+
page = allocatePage(pageLength * 8)
964+
readData(readBuffer, page.getBaseObject, page.getBaseOffset, pageLength)
977965
// Restore cursor variable to make this map able to be serialized again on executors.
978-
cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET
966+
cursor = pageLength * 8 + page.getBaseOffset
979967
}
980968

981969
override def readExternal(in: ObjectInput): Unit = {
@@ -985,6 +973,26 @@ private[execution] final class LongToUnsafeRowMap(
985973
// Kryo deserialization entry point: adapts the Kryo `Input` to the three reader functions
// expected by the shared `read` implementation.
override def read(kryo: Kryo, in: Input): Unit = {
  val nextBoolean = () => in.readBoolean()
  val nextLong = () => in.readLong()
  val fillBuffer: (Array[Byte], Int, Int) => Unit =
    (bytes, start, count) => in.readBytes(bytes, start, count)
  read(nextBoolean, nextLong, fillBuffer)
}
976+
977+
/**
 * A fixed-length array of longs stored in a page obtained from `allocatePage`.
 *
 * All elements are set to 0 on construction. Element `i` lives at byte offset `i * 8` from
 * the start of the page. Indices are not bounds-checked.
 *
 * @param length number of long elements in the array
 */
private class UnsafeLongArray(val length: Int) {
  // 8 bytes per element. The product is computed as Long: `length` can reach 2^30, and an
  // Int product would overflow from 2^28 elements on, requesting a page that is too small.
  val memoryBlock: MemoryBlock = allocatePage(length * 8L)

  // Zero the whole page in one bulk call instead of element by element.
  Platform.setMemory(memoryBlock.getBaseObject, memoryBlock.getBaseOffset, length * 8L, 0.toByte)

  /** Returns the element at `index`. */
  def apply(index: Int): Long = {
    // `index * 8L` keeps the byte offset in Long range for indices >= 2^28.
    Platform.getLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8L)
  }

  /** Stores `value` at `index`. */
  def update(index: Int, value: Long): Unit = {
    Platform.putLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8L, value)
  }

  /** Hands the backing page back via `freePage`; the array must not be used afterwards. */
  def free(): Unit = {
    freePage(memoryBlock)
  }
}
988996
}
989997

990998
class LongHashedRelation(

0 commit comments

Comments
 (0)