@@ -22,7 +22,7 @@ import java.io._
2222import com .esotericsoftware .kryo .{Kryo , KryoSerializable }
2323import com .esotericsoftware .kryo .io .{Input , Output }
2424
25- import org .apache .spark .{SparkConf , SparkEnv , SparkException , SparkUnsupportedOperationException }
25+ import org .apache .spark .{SparkConf , SparkEnv , SparkUnsupportedOperationException }
2626import org .apache .spark .internal .config .{BUFFER_PAGESIZE , MEMORY_OFFHEAP_ENABLED }
2727import org .apache .spark .memory ._
2828import org .apache .spark .sql .catalyst .InternalRow
@@ -32,6 +32,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
3232import org .apache .spark .sql .types .LongType
3333import org .apache .spark .unsafe .Platform
3434import org .apache .spark .unsafe .map .BytesToBytesMap
35+ import org .apache .spark .unsafe .memory .MemoryBlock
3536import org .apache .spark .util .{KnownSizeEstimation , Utils }
3637
3738/**
@@ -535,7 +536,7 @@ private[execution] final class LongToUnsafeRowMap(
535536 val mm : TaskMemoryManager ,
536537 capacity : Int ,
537538 ignoresDuplicatedKey : Boolean = false )
538- extends MemoryConsumer (mm, MemoryMode . ON_HEAP ) with Externalizable with KryoSerializable {
539+ extends MemoryConsumer (mm, mm.getTungstenMemoryMode ) with Externalizable with KryoSerializable {
539540
540541 // Whether the keys are stored in dense mode or not.
541542 private var isDense = false
@@ -550,15 +551,15 @@ private[execution] final class LongToUnsafeRowMap(
550551 //
551552 // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ...
552553 // Dense mode: [offset1 | size1] [offset2 | size2]
553- private var array : Array [ Long ] = null
554+ private var array : UnsafeLongArray = null
554555 private var mask : Int = 0
555556
556557 // The page to store all bytes of UnsafeRow and the pointer to next rows.
557558 // [row1][pointer1] [row2][pointer2]
558- private var page : Array [ Long ] = null
559+ private var page : MemoryBlock = null
559560
560561 // Current write cursor in the page.
561- private var cursor : Long = Platform . LONG_ARRAY_OFFSET
562+ private var cursor : Long = - 1
562563
563564 // The number of bits for size in address
564565 private val SIZE_BITS = 28
@@ -583,24 +584,15 @@ private[execution] final class LongToUnsafeRowMap(
583584 0 )
584585 }
585586
586- private def ensureAcquireMemory (size : Long ): Unit = {
587- // do not support spilling
588- val got = acquireMemory(size)
589- if (got < size) {
590- freeMemory(got)
591- throw QueryExecutionErrors .cannotAcquireMemoryToBuildLongHashedRelationError(size, got)
592- }
593- }
594-
/**
 * Allocates the initial key index and the first data page.
 *
 * Nothing is allocated when `mm` is null. Otherwise the index gets `2 * n` longs, where `n`
 * is the smallest power of two that is not less than `capacity`, and the data page starts
 * at 1 MB with the write cursor placed at its base offset.
 */
private def init(): Unit = {
  if (mm != null) {
    require(capacity < 512000000, "Cannot broadcast 512 million or more rows")
    // Round the requested capacity up to a power of two.
    var slots = 1
    while (slots < capacity) {
      slots <<= 1
    }
    val indexLength = slots * 2
    array = new UnsafeLongArray(indexLength)
    // Positions advance in steps of two, so the mask keeps them even and in range.
    mask = indexLength - 2
    page = allocatePage(1 << 20) // 1M bytes
    cursor = page.getBaseOffset
  }
}
606598
@@ -616,7 +608,7 @@ private[execution] final class LongToUnsafeRowMap(
/**
 * Returns the total memory consumption in bytes: eight bytes per element of the key index
 * plus the size of the data page.
 */
def getTotalMemoryConsumption: Long = {
  val indexBytes = array.length * 8L
  val pageBytes = page.size()
  indexBytes + pageBytes
}
620612
621613 /**
622614 * Returns the first slot of array that store the keys (sparse mode).
@@ -632,19 +624,19 @@ private[execution] final class LongToUnsafeRowMap(
private def nextSlot(pos: Int): Int = {
  // Step over one two-long entry, then apply `mask` to keep the position inside the array.
  val advanced = pos + 2
  advanced & mask
}
633625
private[this] def toAddress(offset: Long, size: Int): Long = {
  // Pack the offset above the low SIZE_BITS bits and OR the size in below it.
  val shiftedOffset = offset << SIZE_BITS
  shiftedOffset | size
}
637629
private[this] def toOffset(address: Long): Long = {
  // Discard the low SIZE_BITS bits (the size part); what remains is the offset.
  val offsetPart = address >>> SIZE_BITS
  offsetPart
}
641633
private[this] def toSize(address: Long): Int = {
  // Keep only the bits selected by SIZE_MASK, which hold the size.
  val sizePart = address & SIZE_MASK
  sizePart.toInt
}
645637
private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
  // The offset decoded from `address` is relative to the page, so add the page's base offset.
  val rowStart = page.getBaseOffset + toOffset(address)
  val rowBytes = toSize(address)
  resultRow.pointTo(page.getBaseObject, rowStart, rowBytes)
  resultRow
}
650642
@@ -681,8 +673,8 @@ private[execution] final class LongToUnsafeRowMap(
override def next(): UnsafeRow = {
  val base = page.getBaseObject
  // Locate the current row inside the page from the packed address.
  val rowStart = page.getBaseOffset + toOffset(addr)
  val rowBytes = toSize(addr)
  resultRow.pointTo(base, rowStart, rowBytes)
  // The long stored right after the row bytes is the address of the next value.
  addr = Platform.getLong(base, rowStart + rowBytes)
  resultRow
}
688680 }
@@ -777,12 +769,13 @@ private[execution] final class LongToUnsafeRowMap(
777769
778770 // copy the bytes of UnsafeRow
779771 val offset = cursor
780- Platform .copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
772+ Platform .copyMemory(row.getBaseObject, row.getBaseOffset, page.getBaseObject, cursor,
773+ row.getSizeInBytes)
781774 cursor += row.getSizeInBytes
782- Platform .putLong(page, cursor, 0 )
775+ Platform .putLong(page.getBaseObject , cursor, 0 )
783776 cursor += 8
784777 numValues += 1
785- updateIndex(key, pos, toAddress(offset, row.getSizeInBytes))
778+ updateIndex(key, pos, toAddress(offset - page.getBaseOffset , row.getSizeInBytes))
786779 }
787780
788781 private def findKeyPosition (key : Long ): Int = {
@@ -816,35 +809,32 @@ private[execution] final class LongToUnsafeRowMap(
816809 } else {
817810 // there are some values for this key, put the address in the front of them.
818811 val pointer = toOffset(address) + toSize(address)
819- Platform .putLong(page, pointer, array(pos + 1 ))
812+ Platform .putLong(page.getBaseObject, page.getBaseOffset + pointer, array(pos + 1 ))
820813 array(pos + 1 ) = address
821814 }
822815 }
823816
/**
 * Ensures the data page has room for one more row of `inputRowSize` bytes plus the 8-byte
 * pointer to the next value, moving the contents to a larger page when it does not.
 *
 * The new page is at least as large as required and otherwise doubles the current size,
 * capped at 2^30 words. The bytes written so far are copied over, the old page is freed and
 * the write cursor is re-anchored on the new page.
 *
 * @param inputRowSize size in bytes of the row about to be appended
 * @throws the error built by `cannotBuildHashedRelationLargerThan8GError` when more than
 *         2^30 words would be needed
 */
private def grow(inputRowSize: Int): Unit = {
  // Bytes already written to the current page.
  val usedBytes = cursor - page.getBaseOffset
  // There is 8 bytes for the pointer to next value
  val neededNumWords = (usedBytes + 8 + inputRowSize + 7) / 8
  if (neededNumWords > page.size() / 8) {
    if (neededNumWords > (1 << 30)) {
      throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError()
    }
    val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30))
    // Keep the byte count as a Long: converting to Int and multiplying by 8 overflows
    // once the page reaches 2^28 words.
    val newPage = allocatePage(newNumWords * 8L)
    Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject,
      newPage.getBaseOffset, usedBytes)
    freePage(page)
    page = newPage
    // `cursor` is an absolute offset, and the new page's base offset is not guaranteed to
    // equal the old one, so rebuild the cursor from the number of bytes in use.
    cursor = newPage.getBaseOffset + usedBytes
  }
}
841832
842833 private def growArray (): Unit = {
843834 var old_array = array
844835 val n = array.length
845836 numKeys = 0
846- ensureAcquireMemory(n * 2 * 8L )
847- array = new Array [Long ](n * 2 )
837+ array = new UnsafeLongArray (n * 2 )
848838 mask = n * 2 - 2
849839 var i = 0
850840 while (i < old_array.length) {
@@ -854,8 +844,8 @@ private[execution] final class LongToUnsafeRowMap(
854844 }
855845 i += 2
856846 }
847+ old_array.free()
857848 old_array = null // release the reference to old array
858- freeMemory(n * 8L )
859849 }
860850
861851 /**
@@ -866,14 +856,7 @@ private[execution] final class LongToUnsafeRowMap(
866856 // Convert to dense mode if it does not require more memory or could fit within L1 cache
867857 // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value
868858 if (range >= 0 && (range < array.length || range < 1024 )) {
869- try {
870- ensureAcquireMemory((range + 1 ) * 8L )
871- } catch {
872- case e : SparkException =>
873- // there is no enough memory to convert
874- return
875- }
876- val denseArray = new Array [Long ]((range + 1 ).toInt)
859+ val denseArray = new UnsafeLongArray ((range + 1 ).toInt)
877860 var i = 0
878861 while (i < array.length) {
879862 if (array(i + 1 ) > 0 ) {
@@ -882,10 +865,9 @@ private[execution] final class LongToUnsafeRowMap(
882865 }
883866 i += 2
884867 }
885- val old_length = array.length
868+ array.free()
886869 array = denseArray
887870 isDense = true
888- freeMemory(old_length * 8L )
889871 }
890872 }
891873
@@ -894,25 +876,26 @@ private[execution] final class LongToUnsafeRowMap(
894876 */
def free(): Unit = {
  // Release the data page, if any, and drop the reference so it is not released twice.
  Option(page).foreach { currentPage =>
    freePage(currentPage)
    page = null
  }
  // Same for the key index.
  Option(array).foreach { currentArray =>
    currentArray.free()
    array = null
  }
}
905887
906- private def writeLongArray (
888+ private def writeBytes (
907889 writeBuffer : (Array [Byte ], Int , Int ) => Unit ,
908- arr : Array [Long ],
890+ baseObject : Object ,
891+ baseOffset : Long ,
909892 len : Int ): Unit = {
910893 val buffer = new Array [Byte ](4 << 10 )
911- var offset : Long = Platform . LONG_ARRAY_OFFSET
912- val end = len * 8L + Platform . LONG_ARRAY_OFFSET
894+ var offset : Long = baseOffset
895+ val end = len * 8L + offset
913896 while (offset < end) {
914897 val size = Math .min(buffer.length, end - offset)
915- Platform .copyMemory(arr , offset, buffer, Platform .BYTE_ARRAY_OFFSET , size)
898+ Platform .copyMemory(baseObject , offset, buffer, Platform .BYTE_ARRAY_OFFSET , size)
916899 writeBuffer(buffer, 0 , size.toInt)
917900 offset += size
918901 }
@@ -929,10 +912,11 @@ private[execution] final class LongToUnsafeRowMap(
929912 writeLong(numValues)
930913
931914 writeLong(array.length)
932- writeLongArray(writeBuffer, array, array.length)
933- val used = ((cursor - Platform .LONG_ARRAY_OFFSET ) / 8 ).toInt
915+ writeBytes(writeBuffer,
916+ array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, array.length)
917+ val used = ((cursor - page.getBaseOffset) / 8 ).toInt
934918 writeLong(used)
935- writeLongArray (writeBuffer, page, used)
919+ writeBytes (writeBuffer, page.getBaseObject, page.getBaseOffset , used)
936920 }
937921
938922 override def writeExternal (output : ObjectOutput ): Unit = {
@@ -943,20 +927,20 @@ private[execution] final class LongToUnsafeRowMap(
943927 write(out.writeBoolean, out.writeLong, out.write)
944928 }
945929
/**
 * Fills `length` longs (`length * 8` bytes) of memory starting at `baseObject` /
 * `baseOffset` with bytes obtained from `readBuffer`.
 *
 * The data is pulled through a 4 KB temporary byte array, one chunk at a time.
 *
 * @param readBuffer callback invoked as `(buffer, 0, count)` to fill the first `count`
 *                   bytes of the buffer
 * @param baseObject destination base object passed to `Platform.copyMemory`
 * @param baseOffset destination offset of the first byte
 * @param length number of 8-byte words to read
 */
private def readData(
    readBuffer: (Array[Byte], Int, Int) => Unit,
    baseObject: Object,
    baseOffset: Long,
    length: Int): Unit = {
  val chunk = new Array[Byte](4 << 10)
  val totalBytes = length * 8L
  var copied = 0L
  while (copied < totalBytes) {
    // Never read past the requested region on the last chunk.
    val count = Math.min(chunk.length, totalBytes - copied)
    readBuffer(chunk, 0, count.toInt)
    Platform.copyMemory(chunk, Platform.BYTE_ARRAY_OFFSET, baseObject, baseOffset + copied, count)
    copied += count
  }
}
961945
962946 private def read (
@@ -971,11 +955,15 @@ private[execution] final class LongToUnsafeRowMap(
971955
972956 val length = readLong().toInt
973957 mask = length - 2
974- array = readLongArray(readBuffer, length)
958+ array.free()
959+ array = new UnsafeLongArray (length)
960+ readData(readBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, length)
975961 val pageLength = readLong().toInt
976- page = readLongArray(readBuffer, pageLength)
962+ freePage(page)
963+ page = allocatePage(pageLength * 8 )
964+ readData(readBuffer, page.getBaseObject, page.getBaseOffset, pageLength)
977965 // Restore cursor variable to make this map able to be serialized again on executors.
978- cursor = pageLength * 8 + Platform . LONG_ARRAY_OFFSET
966+ cursor = pageLength * 8 + page.getBaseOffset
979967 }
980968
981969 override def readExternal (in : ObjectInput ): Unit = {
@@ -985,6 +973,26 @@ private[execution] final class LongToUnsafeRowMap(
override def read(kryo: Kryo, in: Input): Unit = {
  // Adapt the Kryo input to the three reader callbacks expected by the shared `read`.
  val nextBoolean = () => in.readBoolean()
  val nextLong = () => in.readLong()
  val fillBuffer =
    (bytes: Array[Byte], offset: Int, count: Int) => in.readBytes(bytes, offset, count)
  read(nextBoolean, nextLong, fillBuffer)
}
976+
/**
 * A fixed-length array of longs stored in a page obtained from the enclosing consumer's
 * `allocatePage`. All elements start at zero; call `free()` to give the page back.
 *
 * @param length number of long elements
 */
private class UnsafeLongArray(val length: Int) {
  // Widen to Long before multiplying: `length * 8` overflows Int once length reaches 2^28.
  val memoryBlock: MemoryBlock = allocatePage(length * 8L)

  // Zero the whole array with one bulk fill instead of one write per element.
  Platform.setMemory(memoryBlock.getBaseObject, memoryBlock.getBaseOffset, length * 8L, 0.toByte)

  /** Returns the element at `index`; no bounds check is performed. */
  def apply(index: Int): Long = {
    // `index * 8L` keeps the byte offset in Long arithmetic for indexes >= 2^28.
    Platform.getLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8L)
  }

  /** Stores `value` at `index`; no bounds check is performed. */
  def update(index: Int, value: Long): Unit = {
    Platform.putLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8L, value)
  }

  /** Releases the backing page through `freePage`. */
  def free(): Unit = {
    freePage(memoryBlock)
  }
}
988996}
989997
990998class LongHashedRelation (
0 commit comments