From d43563d3fc6228c1eb21d9a962f12b2608b52744 Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Thu, 27 Aug 2020 12:53:31 +1000 Subject: [PATCH 1/7] wip --- pom.xml | 45 +- .../tfrecord/SharedSparkSessionSuite.scala | 33 +- .../tfrecord/TFRecordDeserializerTest.scala | 365 +++++++++----- .../tfrecord/TFRecordSerializerTest.scala | 452 ++++++++++++------ .../datasources/tfrecord/TestingUtils.scala | 396 ++++++++------- 5 files changed, 850 insertions(+), 441 deletions(-) diff --git a/pom.xml b/pom.xml index 9cc130f..ce6fb09 100644 --- a/pom.xml +++ b/pom.xml @@ -4,9 +4,9 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 com.linkedin.sparktfrecord - spark-tfrecord_2.11 + spark-tfrecord_${scala.binary.version} jar - 0.2.2 + 0.2.3 spark-tfrecord https://github.com/linkedin/spark-tfrecord TensorFlow TFRecord data source for Apache Spark @@ -28,10 +28,8 @@ UTF-8 3.2.2 - 2.11 - 2.11.8 1.0 - 2.2.6 + 3.2.2 3.0 1.8 2.4.6 @@ -87,7 +85,7 @@ incremental true - ${scala.compiler.version} + ${scala.version} false @@ -173,6 +171,20 @@ + + + org.spurint.maven.plugins + scala-cross-maven-plugin + 0.2.1 + + + rewrite-pom + + rewrite-pom + + + + @@ -223,6 +235,10 @@ + + org.spurint.maven.plugins + scala-cross-maven-plugin + @@ -321,6 +337,23 @@ + + + scala-2.11 + + 2.11 + 2.11.8 + + + + + scala-2.12 + + 2.12 + 2.12.10 + + + diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala index 72b2509..7f12ef5 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala @@ -1,18 +1,18 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.linkedin.spark.datasources.tfrecord import java.io.File @@ -20,8 +20,8 @@ import java.io.File import org.apache.commons.io.FileUtils import org.apache.spark.SharedSparkSession import org.junit.{After, Before} -import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} - +import org.scalatest._ +import matchers._ trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll @@ -42,4 +42,3 @@ class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { super.tearDown() } } - diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala index 6899b2d..51250cf 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala @@ -1,18 +1,18 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.linkedin.spark.datasources.tfrecord import com.google.protobuf.ByteString @@ -20,50 +20,114 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.scalatest.{Matchers, WordSpec} import org.tensorflow.example._ import TestingUtils._ - +import org.scalatest._ +import matchers._ class TFRecordDeserializerTest extends WordSpec with Matchers { - val intFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1)).build() - val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(23L)).build() - val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F)).build() - val doubleFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(14.0F)).build() - val decimalFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F)).build() - val longArrFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(7L).build()).build() - val doubleArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(1F).addValue(2F).build()).build() - val decimalArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(3F).addValue(5F).build()).build() - val strFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()).build() - val strListFeature 
=Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)) - .addValue(ByteString.copyFrom("r3".getBytes)).build()).build() - val binaryFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r4".getBytes))).build() - val binaryListFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r5".getBytes)) - .addValue(ByteString.copyFrom("r6".getBytes)).build()).build() - - private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) + val intFeature = Feature + .newBuilder() + .setInt64List(Int64List.newBuilder().addValue(1)) + .build() + val longFeature = Feature + .newBuilder() + .setInt64List(Int64List.newBuilder().addValue(23L)) + .build() + val floatFeature = Feature + .newBuilder() + .setFloatList(FloatList.newBuilder().addValue(10.0F)) + .build() + val doubleFeature = Feature + .newBuilder() + .setFloatList(FloatList.newBuilder().addValue(14.0F)) + .build() + val decimalFeature = Feature + .newBuilder() + .setFloatList(FloatList.newBuilder().addValue(2.5F)) + .build() + val longArrFeature = Feature + .newBuilder() + .setInt64List(Int64List.newBuilder().addValue(-2L).addValue(7L).build()) + .build() + val doubleArrFeature = Feature + .newBuilder() + .setFloatList(FloatList.newBuilder().addValue(1F).addValue(2F).build()) + .build() + val decimalArrFeature = Feature + .newBuilder() + .setFloatList(FloatList.newBuilder().addValue(3F).addValue(5F).build()) + .build() + val strFeature = Feature + .newBuilder() + .setBytesList( + BytesList + .newBuilder() + .addValue(ByteString.copyFrom("r1".getBytes)) + .build() + ) + .build() + val strListFeature = Feature + .newBuilder() + .setBytesList( + BytesList + .newBuilder() + .addValue(ByteString.copyFrom("r2".getBytes)) + .addValue(ByteString.copyFrom("r3".getBytes)) + .build() + ) + .build() + val binaryFeature = Feature + .newBuilder() + .setBytesList( + 
BytesList.newBuilder().addValue(ByteString.copyFrom("r4".getBytes)) + ) + .build() + val binaryListFeature = Feature + .newBuilder() + .setBytesList( + BytesList + .newBuilder() + .addValue(ByteString.copyFrom("r5".getBytes)) + .addValue(ByteString.copyFrom("r6".getBytes)) + .build() + ) + .build() + + private def createArray(values: Any*): ArrayData = + new GenericArrayData(values.toArray) "Deserialize tfrecord to spark internalRow" should { "Serialize tfrecord example to spark internalRow" in { - val schema = StructType(List( - StructField("IntegerLabel", IntegerType), - StructField("LongLabel", LongType), - StructField("FloatLabel", FloatType), - StructField("DoubleLabel", DoubleType), - StructField("DecimalLabel", DataTypes.createDecimalType()), - StructField("LongArrayLabel", ArrayType(LongType)), - StructField("DoubleArrayLabel", ArrayType(DoubleType)), - StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), - StructField("StrLabel", StringType), - StructField("StrArrayLabel", ArrayType(StringType)), - StructField("BinaryTypeLabel", BinaryType), - StructField("BinaryTypeArrayLabel", ArrayType(BinaryType)) - )) + val schema = StructType( + List( + StructField("IntegerLabel", IntegerType), + StructField("LongLabel", LongType), + StructField("FloatLabel", FloatType), + StructField("DoubleLabel", DoubleType), + StructField("DecimalLabel", DataTypes.createDecimalType()), + StructField("LongArrayLabel", ArrayType(LongType)), + StructField("DoubleArrayLabel", ArrayType(DoubleType)), + StructField( + "DecimalArrayLabel", + ArrayType(DataTypes.createDecimalType()) + ), + StructField("StrLabel", StringType), + StructField("StrArrayLabel", ArrayType(StringType)), + StructField("BinaryTypeLabel", BinaryType), + StructField("BinaryTypeArrayLabel", ArrayType(BinaryType)) + ) + ) val expectedInternalRow = InternalRow.fromSeq( - Array[Any](1, 23L, 10.0F, 14.0, Decimal(2.5d), - createArray(-2L,7L), + Array[Any]( + 1, + 23L, + 10.0F, + 14.0, + 
Decimal(2.5d), + createArray(-2L, 7L), createArray(1.0, 2.0), createArray(Decimal(3.0), Decimal(5.0)), UTF8String.fromString("r1"), @@ -74,7 +138,8 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { ) //Build example - val features = Features.newBuilder() + val features = Features + .newBuilder() .putFeature("IntegerLabel", intFeature) .putFeature("LongLabel", longFeature) .putFeature("FloatLabel", floatFeature) @@ -88,49 +153,82 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { .putFeature("BinaryTypeLabel", binaryFeature) .putFeature("BinaryTypeArrayLabel", binaryListFeature) .build() - val example = Example.newBuilder() + val example = Example + .newBuilder() .setFeatures(features) .build() val deserializer = new TFRecordDeserializer(schema) val actualInternalRow = deserializer.deserializeExample(example) - assert(actualInternalRow ~== (expectedInternalRow,schema)) + assert(actualInternalRow ~== (expectedInternalRow, schema)) } "Serialize spark internalRow to tfrecord sequenceExample" in { - val schema = StructType(List( - StructField("FloatLabel", FloatType), - StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), - StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), - StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))), - StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))), - StructField("ByteArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) - )) + val schema = StructType( + List( + StructField("FloatLabel", FloatType), + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField( + "FloatArrayOfArrayLabel", + ArrayType(ArrayType(FloatType)) + ), + StructField( + "DecimalArrayOfArrayLabel", + ArrayType(ArrayType(DataTypes.createDecimalType())) + ), + StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))), + StructField("ByteArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) + ) + ) val 
expectedInternalRow = InternalRow.fromSeq( - Array[Any](10.0F, + Array[Any]( + 10.0F, createArray(createArray(-2L, 7L)), createArray(createArray(10.0F), createArray(1.0F, 2.0F)), createArray(createArray(Decimal(3.0), Decimal(5.0))), - createArray(createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), - createArray(UTF8String.fromString("r1"))), - createArray(createArray("r5".getBytes, "r6".getBytes), createArray("r4".getBytes)) + createArray( + createArray( + UTF8String.fromString("r2"), + UTF8String.fromString("r3") + ), + createArray(UTF8String.fromString("r1")) + ), + createArray( + createArray("r5".getBytes, "r6".getBytes), + createArray("r4".getBytes) + ) ) ) //Build sequence example - val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() - val floatFeatureList = FeatureList.newBuilder().addFeature(floatFeature).addFeature(doubleArrFeature).build() - val decimalFeatureList = FeatureList.newBuilder().addFeature(decimalArrFeature).build() - val stringFeatureList = FeatureList.newBuilder().addFeature(strListFeature).addFeature(strFeature).build() - val binaryFeatureList = FeatureList.newBuilder().addFeature(binaryListFeature).addFeature(binaryFeature).build() - + val int64FeatureList = + FeatureList.newBuilder().addFeature(longArrFeature).build() + val floatFeatureList = FeatureList + .newBuilder() + .addFeature(floatFeature) + .addFeature(doubleArrFeature) + .build() + val decimalFeatureList = + FeatureList.newBuilder().addFeature(decimalArrFeature).build() + val stringFeatureList = FeatureList + .newBuilder() + .addFeature(strListFeature) + .addFeature(strFeature) + .build() + val binaryFeatureList = FeatureList + .newBuilder() + .addFeature(binaryListFeature) + .addFeature(binaryFeature) + .build() - val features = Features.newBuilder() + val features = Features + .newBuilder() .putFeature("FloatLabel", floatFeature) - val featureLists = FeatureLists.newBuilder() + val featureLists = FeatureLists + .newBuilder() 
.putFeatureList("LongArrayOfArrayLabel", int64FeatureList) .putFeatureList("FloatArrayOfArrayLabel", floatFeatureList) .putFeatureList("DecimalArrayOfArrayLabel", decimalFeatureList) @@ -138,24 +236,29 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { .putFeatureList("ByteArrayOfArrayLabel", binaryFeatureList) .build() - val seqExample = SequenceExample.newBuilder() + val seqExample = SequenceExample + .newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() val deserializer = new TFRecordDeserializer(schema) - val actualInternalRow = deserializer.deserializeSequenceExample(seqExample) + val actualInternalRow = + deserializer.deserializeSequenceExample(seqExample) assert(actualInternalRow ~== (expectedInternalRow, schema)) } "Throw an exception for unsupported data types" in { val features = Features.newBuilder().putFeature("MapLabel1", intFeature) - val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() - val featureLists = FeatureLists.newBuilder().putFeatureList("MapLabel2", int64FeatureList) + val int64FeatureList = + FeatureList.newBuilder().addFeature(longArrFeature).build() + val featureLists = + FeatureLists.newBuilder().putFeatureList("MapLabel2", int64FeatureList) intercept[RuntimeException] { - val example = Example.newBuilder() + val example = Example + .newBuilder() .setFeatures(features) .build() val schema = StructType(List(StructField("MapLabel1", TimestampType))) @@ -164,7 +267,8 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { } intercept[RuntimeException] { - val seqExample = SequenceExample.newBuilder() + val seqExample = SequenceExample + .newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() @@ -175,45 +279,65 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { } "Throw an exception for non-nullable data types" in { - val features = Features.newBuilder().putFeature("FloatLabel", floatFeature) - val int64FeatureList = 
FeatureList.newBuilder().addFeature(longArrFeature).build() - val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList) + val features = + Features.newBuilder().putFeature("FloatLabel", floatFeature) + val int64FeatureList = + FeatureList.newBuilder().addFeature(longArrFeature).build() + val featureLists = FeatureLists + .newBuilder() + .putFeatureList("LongArrayOfArrayLabel", int64FeatureList) intercept[NullPointerException] { - val example = Example.newBuilder() + val example = Example + .newBuilder() .setFeatures(features) .build() - val schema = StructType(List(StructField("MissingLabel", FloatType, nullable = false))) + val schema = StructType( + List(StructField("MissingLabel", FloatType, nullable = false)) + ) val deserializer = new TFRecordDeserializer(schema) deserializer.deserializeExample(example) } intercept[NullPointerException] { - val seqExample = SequenceExample.newBuilder() + val seqExample = SequenceExample + .newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() - val schema = StructType(List(StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = false))) + val schema = StructType( + List( + StructField( + "MissingLabel", + ArrayType(ArrayType(LongType)), + nullable = false + ) + ) + ) val deserializer = new TFRecordDeserializer(schema) deserializer.deserializeSequenceExample(seqExample) } } - "Return null fields for nullable data types" in { - val features = Features.newBuilder().putFeature("FloatLabel", floatFeature) - val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() - val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList) + val features = + Features.newBuilder().putFeature("FloatLabel", floatFeature) + val int64FeatureList = + FeatureList.newBuilder().addFeature(longArrFeature).build() + val featureLists = FeatureLists + .newBuilder() + .putFeatureList("LongArrayOfArrayLabel", 
int64FeatureList) // Deserialize Example - val schema1 = StructType(List( - StructField("FloatLabel", FloatType), - StructField("MissingLabel", FloatType, nullable = true)) - ) - val expectedInternalRow1 = InternalRow.fromSeq( - Array[Any](10.0F, null) + val schema1 = StructType( + List( + StructField("FloatLabel", FloatType), + StructField("MissingLabel", FloatType, nullable = true) + ) ) - val example = Example.newBuilder() + val expectedInternalRow1 = InternalRow.fromSeq(Array[Any](10.0F, null)) + val example = Example + .newBuilder() .setFeatures(features) .build() val deserializer1 = new TFRecordDeserializer(schema1) @@ -221,27 +345,31 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { assert(actualInternalRow1 ~== (expectedInternalRow1, schema1)) // Deserialize SequenceExample - val schema2 = StructType(List( - StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), - StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = true)) - ) - val expectedInternalRow2 = InternalRow.fromSeq( - Array[Any]( - createArray(createArray(-2L, 7L)), null) + val schema2 = StructType( + List( + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField( + "MissingLabel", + ArrayType(ArrayType(LongType)), + nullable = true + ) + ) ) - val seqExample = SequenceExample.newBuilder() + val expectedInternalRow2 = + InternalRow.fromSeq(Array[Any](createArray(createArray(-2L, 7L)), null)) + val seqExample = SequenceExample + .newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() val deserializer2 = new TFRecordDeserializer(schema2) - val actualInternalRow2 = deserializer2.deserializeSequenceExample(seqExample) + val actualInternalRow2 = + deserializer2.deserializeSequenceExample(seqExample) assert(actualInternalRow2 ~== (expectedInternalRow2, schema2)) } - val schema = StructType(Array( - StructField("LongLabel", LongType)) - ) + val schema = StructType(Array(StructField("LongLabel", LongType))) 
val deserializer = new TFRecordDeserializer(schema) "Test Int64ListFeature2SeqLong" in { @@ -271,9 +399,16 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { } "Test bytesListFeature2SeqArrayByte" in { - val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build() + val bytesList = BytesList + .newBuilder() + .addValue(ByteString.copyFrom("str-input".getBytes)) + .build() val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() - assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head === "str-input".getBytes.deep) + assert( + deserializer + .bytesListFeature2SeqArrayByte(bytesFeature) + .head === "str-input".getBytes.deep + ) // Throw exception if type doesn't match intercept[RuntimeException] { @@ -284,10 +419,16 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { } "Test bytesListFeature2SeqString" in { - val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes)) - .addValue(ByteString.copyFrom("bob".getBytes)).build() + val bytesList = BytesList + .newBuilder() + .addValue(ByteString.copyFrom("alice".getBytes)) + .addValue(ByteString.copyFrom("bob".getBytes)) + .build() val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() - assert(deserializer.bytesListFeature2SeqString(bytesFeature) === Seq("alice", "bob")) + assert( + deserializer + .bytesListFeature2SeqString(bytesFeature) === Seq("alice", "bob") + ) // Throw exception if type doesn't match intercept[RuntimeException] { diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala index d1671a5..47b9e00 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala @@ -1,18 +1,18 @@ /** - * Copyright 2016 The TensorFlow 
Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.linkedin.spark.datasources.tfrecord import org.apache.spark.sql.catalyst.InternalRow @@ -20,26 +20,34 @@ import org.tensorflow.example._ import org.apache.spark.sql.types.{StructField, _} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.unsafe.types.UTF8String -import org.scalatest.{Matchers, WordSpec} +import org.scalatest._ +import matchers._ import scala.collection.JavaConverters._ import TestingUtils._ class TFRecordSerializerTest extends WordSpec with Matchers { - private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) + private def createArray(values: Any*): ArrayData = + new GenericArrayData(values.toArray) "Serialize spark internalRow to tfrecord" should { "Serialize decimal internalRow to tfrecord example" in { - val schemaStructType = StructType(Array( - StructField("DecimalLabel", DataTypes.createDecimalType()), - StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())) - )) + val schemaStructType = StructType( + Array( + StructField("DecimalLabel", DataTypes.createDecimalType()), + StructField( + "DecimalArrayLabel", + ArrayType(DataTypes.createDecimalType()) + ) + ) + ) val serializer = new TFRecordSerializer(schemaStructType) val decimalArray = Array(Decimal(4.0), Decimal(8.0)) - val rowArray = Array[Any](Decimal(6.5), ArrayData.toArrayData(decimalArray)) + val rowArray = + Array[Any](Decimal(6.5), ArrayData.toArrayData(decimalArray)) val internalRow = InternalRow.fromSeq(rowArray) //Encode Sql InternalRow to TensorFlow example @@ -49,37 +57,57 @@ class TFRecordSerializerTest extends WordSpec with Matchers { val featureMap = example.getFeatures.getFeatureMap.asScala assert(featureMap.size == rowArray.length) - assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert( + featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) 
assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F) - assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat)) + assert( + featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) + assert( + featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq + .map(_.toFloat) ~== decimalArray.map(_.toFloat) + ) } "Serialize complex internalRow to tfrecord example" in { - val schemaStructType = StructType(Array( - StructField("IntegerLabel", IntegerType), - StructField("LongLabel", LongType), - StructField("FloatLabel", FloatType), - StructField("DoubleLabel", DoubleType), - StructField("DecimalLabel", DataTypes.createDecimalType()), - StructField("DoubleArrayLabel", ArrayType(DoubleType)), - StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), - StructField("StrLabel", StringType), - StructField("StrArrayLabel", ArrayType(StringType)), - StructField("BinaryLabel", BinaryType), - StructField("BinaryArrayLabel", ArrayType(BinaryType)) - )) + val schemaStructType = StructType( + Array( + StructField("IntegerLabel", IntegerType), + StructField("LongLabel", LongType), + StructField("FloatLabel", FloatType), + StructField("DoubleLabel", DoubleType), + StructField("DecimalLabel", DataTypes.createDecimalType()), + StructField("DoubleArrayLabel", ArrayType(DoubleType)), + StructField( + "DecimalArrayLabel", + ArrayType(DataTypes.createDecimalType()) + ), + StructField("StrLabel", StringType), + StructField("StrArrayLabel", ArrayType(StringType)), + StructField("BinaryLabel", BinaryType), + StructField("BinaryArrayLabel", ArrayType(BinaryType)) + ) + ) val doubleArray = Array(1.1, 111.1, 11111.1) val decimalArray = Array(Decimal(4.0), Decimal(8.0)) - val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) + 
val byteArray = + Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) val byteArray1 = Array[Byte](-128, 23, 127) - val rowArray = Array[Any](1, 23L, 10.0F, 14.0, Decimal(6.5), + val rowArray = Array[Any]( + 1, + 23L, + 10.0F, + 14.0, + Decimal(6.5), ArrayData.toArrayData(doubleArray), ArrayData.toArrayData(decimalArray), UTF8String.fromString("r1"), - ArrayData.toArrayData(Array(UTF8String.fromString("r2"), UTF8String.fromString("r3"))), + ArrayData.toArrayData( + Array(UTF8String.fromString("r2"), UTF8String.fromString("r3")) + ), byteArray, ArrayData.toArrayData(Array(byteArray, byteArray1)) ) @@ -93,78 +121,139 @@ class TFRecordSerializerTest extends WordSpec with Matchers { val featureMap = example.getFeatures.getFeatureMap.asScala assert(featureMap.size == rowArray.length) - assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert( + featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER + ) assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 1) - assert(featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert( + featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER + ) assert(featureMap("LongLabel").getInt64List.getValue(0).toInt == 23) - assert(featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert( + featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) assert(featureMap("FloatLabel").getFloatList.getValue(0) == 10.0F) - assert(featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert( + featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) assert(featureMap("DoubleLabel").getFloatList.getValue(0) == 14.0F) - assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert( + featureMap("DecimalLabel").getKindCase.getNumber == 
Feature.FLOAT_LIST_FIELD_NUMBER + ) assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F) - assert(featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== doubleArray.map(_.toFloat)) + assert( + featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) + assert( + featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq + .map(_.toFloat) ~== doubleArray.map(_.toFloat) + ) - assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat)) + assert( + featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) + assert( + featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq + .map(_.toFloat) ~== decimalArray.map(_.toFloat) + ) - assert(featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - assert(featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1") + assert( + featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER + ) + assert( + featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1" + ) - assert(featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3")) + assert( + featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER + ) + assert( + featureMap("StrArrayLabel").getBytesList.getValueList.asScala + .map(_.toStringUtf8) === Seq("r2", "r3") + ) - assert(featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.deep == 
byteArray.deep) + assert( + featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER + ) + assert( + featureMap("BinaryLabel").getBytesList + .getValue(0) + .toByteArray + .deep == byteArray.deep + ) - assert(featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) + assert( + featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER + ) + val binaryArrayValue = + featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala + .map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) assert(binaryArrayValue.toArray.deep == Array(byteArray, byteArray1).deep) } "Serialize internalRow to tfrecord sequenceExample" in { - val schemaStructType = StructType(Array( - StructField("IntegerLabel", IntegerType), - StructField("StringArrayLabel", ArrayType(StringType)), - StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), - StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))) , - StructField("DoubleArrayOfArrayLabel", ArrayType(ArrayType(DoubleType))), - StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))), - StructField("StringArrayOfArrayLabel", ArrayType(ArrayType(StringType))), - StructField("BinaryArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) - )) - - val stringList = Array(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))) + val schemaStructType = StructType( + Array( + StructField("IntegerLabel", IntegerType), + StructField("StringArrayLabel", ArrayType(StringType)), + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField( + "FloatArrayOfArrayLabel", + ArrayType(ArrayType(FloatType)) + ), + StructField( + "DoubleArrayOfArrayLabel", + ArrayType(ArrayType(DoubleType)) + ), + 
StructField( + "DecimalArrayOfArrayLabel", + ArrayType(ArrayType(DataTypes.createDecimalType())) + ), + StructField( + "StringArrayOfArrayLabel", + ArrayType(ArrayType(StringType)) + ), + StructField( + "BinaryArrayOfArrayLabel", + ArrayType(ArrayType(BinaryType)) + ) + ) + ) + + val stringList = Array( + UTF8String.fromString("r1"), + UTF8String.fromString("r2"), + UTF8String.fromString(("r3")) + ) val longListOfLists = Array(Array(3L, 5L), Array(-8L, 0L)) val floatListOfLists = Array(Array(1.5F, -6.5F), Array(-8.2F, 0F)) val doubleListOfLists = Array(Array(3.0), Array(6.0, 9.0)) - val decimalListOfLists = Array(Array(Decimal(2.0), Decimal(4.0)), Array(Decimal(6.0))) - val stringListOfLists = Array(Array(UTF8String.fromString("r1")), + val decimalListOfLists = + Array(Array(Decimal(2.0), Decimal(4.0)), Array(Decimal(6.0))) + val stringListOfLists = Array( + Array(UTF8String.fromString("r1")), Array(UTF8String.fromString("r2"), UTF8String.fromString("r3")), - Array(UTF8String.fromString("r4"))) - val binaryListOfLists = stringListOfLists.map(stringList => stringList.map(_.getBytes)) + Array(UTF8String.fromString("r4")) + ) + val binaryListOfLists = + stringListOfLists.map(stringList => stringList.map(_.getBytes)) - val rowArray = Array[Any](10, - createArray(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))), - createArray( - createArray(3L, 5L), - createArray(-8L, 0L) - ), + val rowArray = Array[Any]( + 10, createArray( - createArray(1.5F, -6.5F), - createArray(-8.2F, 0F) - ), - createArray( - createArray(3.0), - createArray(6.0, 9.0) + UTF8String.fromString("r1"), + UTF8String.fromString("r2"), + UTF8String.fromString(("r3")) ), + createArray(createArray(3L, 5L), createArray(-8L, 0L)), + createArray(createArray(1.5F, -6.5F), createArray(-8.2F, 0F)), + createArray(createArray(3.0), createArray(6.0, 9.0)), createArray( createArray(Decimal(2.0), Decimal(4.0)), createArray(Decimal(6.0)) @@ -174,7 +263,8 @@ class 
TFRecordSerializerTest extends WordSpec with Matchers { createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), createArray(UTF8String.fromString("r4")) ), - createArray(createArray("r1".getBytes()), + createArray( + createArray("r1".getBytes()), createArray("r2".getBytes(), "r3".getBytes), createArray("r4".getBytes()) ) @@ -190,58 +280,96 @@ class TFRecordSerializerTest extends WordSpec with Matchers { val featureListMap = tfexample.getFeatureLists.getFeatureListMap.asScala assert(featureMap.size == 2) - assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) + assert( + featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER + ) assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 10) - assert(featureMap("StringArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - assert(featureMap("StringArrayLabel").getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)) === stringList) + assert( + featureMap("StringArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER + ) + assert( + featureMap("StringArrayLabel").getBytesList.getValueList.asScala + .map(x => UTF8String.fromString(x.toStringUtf8)) === stringList + ) assert(featureListMap.size == 6) - assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map( - _.getInt64List.getValueList.asScala.toSeq) === longListOfLists) + assert( + featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala + .map(_.getInt64List.getValueList.asScala.toSeq) === longListOfLists + ) - assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map( - _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists.map{arr => arr.toSeq}.toSeq) - assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map( - _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists.map{arr => arr.toSeq}.toSeq) + assert( + 
featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map( + _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq + ) ~== floatListOfLists.map { arr => + arr.toSeq + }.toSeq + ) + assert( + featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map( + _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq + ) ~== doubleListOfLists.map { arr => + arr.toSeq + }.toSeq + ) - assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.map( - _.getFloatList.getValueList.asScala.map(x => Decimal(x.toDouble)).toSeq) ~== decimalListOfLists.map{arr => arr.toSeq}.toSeq) + assert( + featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.map( + _.getFloatList.getValueList.asScala + .map(x => Decimal(x.toDouble)) + .toSeq + ) ~== decimalListOfLists.map { arr => + arr.toSeq + }.toSeq + ) - assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( - _.getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)).toSeq) === stringListOfLists) + assert( + featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( + _.getBytesList.getValueList.asScala + .map(x => UTF8String.fromString(x.toStringUtf8)) + .toSeq + ) === stringListOfLists + ) - assert(featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map( - _.getBytesList.getValueList.asScala.map(byteList => byteList.asScala.toSeq)) === binaryListOfLists.map(_.map(_.toSeq))) + assert( + featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map( + _.getBytesList.getValueList.asScala + .map(byteList => byteList.asScala.toSeq) + ) === binaryListOfLists.map(_.map(_.toSeq)) + ) } "Throw an exception for non-nullable data types" in { - val schemaStructType = StructType(Array( - StructField("NonNullLabel", ArrayType(FloatType), nullable = false) - )) + val schemaStructType = StructType( + Array( + StructField("NonNullLabel", ArrayType(FloatType), nullable = false) + ) + ) val internalRow = 
InternalRow.fromSeq(Array[Any](null)) val serializer = new TFRecordSerializer(schemaStructType) - intercept[NullPointerException]{ + intercept[NullPointerException] { serializer.serializeExample(internalRow) } - intercept[NullPointerException]{ + intercept[NullPointerException] { serializer.serializeSequenceExample(internalRow) } } "Omit null fields from Example for nullable data types" in { - val schemaStructType = StructType(Array( - StructField("NullLabel", ArrayType(FloatType), nullable = true), - StructField("FloatArrayLabel", ArrayType(FloatType)) - )) + val schemaStructType = StructType( + Array( + StructField("NullLabel", ArrayType(FloatType), nullable = true), + StructField("FloatArrayLabel", ArrayType(FloatType)) + ) + ) val floatArray = Array(2.5F, 5.0F) - val internalRow = InternalRow.fromSeq( - Array[Any](null, createArray(2.5F, 5.0F)) - ) + val internalRow = + InternalRow.fromSeq(Array[Any](null, createArray(2.5F, 5.0F))) val serializer = new TFRecordSerializer(schemaStructType) val tfexample = serializer.serializeExample(internalRow) @@ -249,74 +377,108 @@ class TFRecordSerializerTest extends WordSpec with Matchers { //Verify each Datatype converted to TensorFlow datatypes val featureMap = tfexample.getFeatures.getFeatureMap.asScala assert(featureMap.size == 1) - assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq) + assert( + featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) + assert( + featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq + .map(_.toFloat) ~== floatArray.toSeq + ) } "Omit null fields from SequenceExample for nullable data types" in { - val schemaStructType = StructType(Array( - StructField("NullLabel", ArrayType(FloatType), nullable = true), - StructField("FloatArrayLabel", ArrayType(FloatType)) - )) + val 
schemaStructType = StructType( + Array( + StructField("NullLabel", ArrayType(FloatType), nullable = true), + StructField("FloatArrayLabel", ArrayType(FloatType)) + ) + ) val floatArray = Array(2.5F, 5.0F) - val internalRow = InternalRow.fromSeq( - Array[Any](null, createArray(2.5F, 5.0F))) + val internalRow = + InternalRow.fromSeq(Array[Any](null, createArray(2.5F, 5.0F))) val serializer = new TFRecordSerializer(schemaStructType) val tfSeqExample = serializer.serializeSequenceExample(internalRow) //Verify each Datatype converted to TensorFlow datatypes val featureMap = tfSeqExample.getContext.getFeatureMap.asScala - val featureListMap = tfSeqExample.getFeatureLists.getFeatureListMap.asScala + val featureListMap = + tfSeqExample.getFeatureLists.getFeatureListMap.asScala assert(featureMap.size == 1) assert(featureListMap.isEmpty) - assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) - assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq) + assert( + featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER + ) + assert( + featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq + .map(_.toFloat) ~== floatArray.toSeq + ) } "Throw an exception for unsupported data types" in { - val schemaStructType = StructType(Array( - StructField("TimestampLabel", TimestampType) - )) + val schemaStructType = + StructType(Array(StructField("TimestampLabel", TimestampType))) - intercept[RuntimeException]{ + intercept[RuntimeException] { new TFRecordSerializer(schemaStructType) } } - val schema = StructType(Array( - StructField("bytesLabel", BinaryType)) - ) + val schema = StructType(Array(StructField("bytesLabel", BinaryType))) val serializer = new TFRecordSerializer(schema) "Test Int64ListFeature" in { val longFeature = serializer.Int64ListFeature(Seq(10L)) - val longListFeature = serializer.Int64ListFeature(Seq(3L,5L,6L)) + val 
longListFeature = serializer.Int64ListFeature(Seq(3L, 5L, 6L)) assert(longFeature.getInt64List.getValueList.asScala.toSeq === Seq(10L)) - assert(longListFeature.getInt64List.getValueList.asScala.toSeq === Seq(3L, 5L, 6L)) + assert( + longListFeature.getInt64List.getValueList.asScala.toSeq === Seq( + 3L, + 5L, + 6L + ) + ) } "Test floatListFeature" in { val floatFeature = serializer.floatListFeature(Seq(10.1F)) - val floatListFeature = serializer.floatListFeature(Seq(3.1F,5.1F,6.1F)) + val floatListFeature = serializer.floatListFeature(Seq(3.1F, 5.1F, 6.1F)) - assert(floatFeature.getFloatList.getValueList.asScala.toSeq === Seq(10.1F)) - assert(floatListFeature.getFloatList.getValueList.asScala.toSeq === Seq(3.1F,5.1F,6.1F)) + assert( + floatFeature.getFloatList.getValueList.asScala.toSeq === Seq(10.1F) + ) + assert( + floatListFeature.getFloatList.getValueList.asScala.toSeq === Seq( + 3.1F, + 5.1F, + 6.1F + ) + ) } "Test bytesListFeature" in { - val bytesFeature = serializer.bytesListFeature(Seq(Array(0xff.toByte, 0xd8.toByte))) - val bytesListFeature = serializer.bytesListFeature(Seq( - Array(0xff.toByte, 0xd8.toByte), - Array(0xff.toByte, 0xd9.toByte))) - - assert(bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === - Seq(Array(0xff.toByte, 0xd8.toByte).deep)) - assert(bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === - Seq(Array(0xff.toByte, 0xd8.toByte).deep, Array(0xff.toByte, 0xd9.toByte).deep)) + val bytesFeature = + serializer.bytesListFeature(Seq(Array(0xff.toByte, 0xd8.toByte))) + val bytesListFeature = serializer.bytesListFeature( + Seq(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte)) + ) + + assert( + bytesFeature.getBytesList.getValueList.asScala + .map(_.toByteArray.deep) === + Seq(Array(0xff.toByte, 0xd8.toByte).deep) + ) + assert( + bytesListFeature.getBytesList.getValueList.asScala + .map(_.toByteArray.deep) === + Seq( + Array(0xff.toByte, 0xd8.toByte).deep, + Array(0xff.toByte, 
0xd9.toByte).deep + ) + ) } } } diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala index 4a5bbea..5d9380a 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala @@ -1,122 +1,122 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.linkedin.spark.datasources.tfrecord import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, SpecializedGetters} +import org.apache.spark.sql.catalyst.expressions.{ + GenericRowWithSchema, + SpecializedGetters +} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ -import org.scalatest.Matchers +import org.scalatest._ +import matchers._ object TestingUtils extends Matchers { /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. + */ implicit class FloatArrayWithAlmostEquals(val left: Seq[Float]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Float], epsilon : Float = 1E-6F): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Float], epsilon: Float = 1E-6F): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a === (b +- epsilon) } - } - else false + } else false } } /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. + */ implicit class DoubleArrayWithAlmostEquals(val left: Seq[Double]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Double], epsilon : Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. 
+ */ + def ~==(right: Seq[Double], epsilon: Double = 1E-6): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a === (b +- epsilon) } - } - else false + } else false } } /** - * Implicit class for comparing two decimal values using absolute tolerance. - */ + * Implicit class for comparing two decimal values using absolute tolerance. + */ implicit class DecimalArrayWithAlmostEquals(val left: Seq[Decimal]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Decimal], epsilon : Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Decimal], epsilon: Double = 1E-6): Boolean = { if (left.size === right.size) { - (left zip right) forall { case (a, b) => a.toDouble === (b.toDouble +- epsilon) } - } - else false + (left zip right) forall { + case (a, b) => a.toDouble === (b.toDouble +- epsilon) + } + } else false } } /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. + */ implicit class FloatMatrixWithAlmostEquals(val left: Seq[Seq[Float]]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Seq[Float]], epsilon : Float = 1E-6F): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Seq[Float]], epsilon: Float = 1E-6F): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a ~== (b, epsilon) } - } - else false + } else false } } /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. 
+ */ implicit class DoubleMatrixWithAlmostEquals(val left: Seq[Seq[Double]]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Seq[Double]], epsilon : Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Seq[Double]], epsilon: Double = 1E-6): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a ~== (b, epsilon) } - } - else false + } else false } } /** - * Implicit class for comparing two decimal values using absolute tolerance. - */ + * Implicit class for comparing two decimal values using absolute tolerance. + */ implicit class DecimalMatrixWithAlmostEquals(val left: Seq[Seq[Decimal]]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Seq[Decimal]], epsilon : Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. 
+ */ + def ~==(right: Seq[Seq[Decimal]], epsilon: Double = 1E-6): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a ~== (b, epsilon) } - } - else false + } else false } } @@ -125,93 +125,122 @@ object TestingUtils extends Matchers { */ implicit class InternalRowWithAlmostEquals(val left: InternalRow) { - private type valueCompare = (SpecializedGetters, SpecializedGetters, Int) => Boolean - private def newValueCompare( - dataType: DataType, - epsilon : Float = 1E-6F): valueCompare = dataType match { - case NullType => (left, right, ordinal) => - left.get(ordinal, null) == right.get(ordinal, null) - - case IntegerType => (left, right, ordinal) => - left.getInt(ordinal) === right.getInt(ordinal) - - case LongType => (left, right, ordinal) => - left.getLong(ordinal) === right.getLong(ordinal) - - case FloatType => (left, right, ordinal) => - left.getFloat(ordinal) === (right.getFloat(ordinal) +- epsilon) - - case DoubleType => (left, right, ordinal) => - left.getDouble(ordinal) === (right.getDouble(ordinal) +- epsilon) - - case DecimalType() => (left, right, ordinal) => - left.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble === - (right.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble - +- epsilon) - - case StringType => (left, right, ordinal) => - left.getUTF8String(ordinal).getBytes === right.getUTF8String(ordinal).getBytes - - case BinaryType => (left, right, ordinal) => - left.getBinary(ordinal) === right.getBinary(ordinal) - - case ArrayType(elementType, _) => (left, right, ordinal) => - if (left.get(ordinal, null) == null || right.get(ordinal, null) == null ){ - left.get(ordinal, null) == right.get(ordinal, null) - } else { - val leftArray = left.getArray(ordinal) - val rightArray = right.getArray(ordinal) - if (leftArray.numElements == rightArray.numElements) { - val len = leftArray.numElements() - val elementValueCompare = 
newValueCompare(elementType) - var result = true - var idx: Int = 0 - while (idx < len && result) { - result = elementValueCompare(leftArray, rightArray, idx) - idx += 1 - } - result - } else false - } - case _ => throw new RuntimeException(s"Cannot convert field to unsupported data type ${dataType}") - } + private type valueCompare = + (SpecializedGetters, SpecializedGetters, Int) => Boolean + private def newValueCompare(dataType: DataType, + epsilon: Float = 1E-6F): valueCompare = + dataType match { + case NullType => + (left, right, ordinal) => + left.get(ordinal, null) == right.get(ordinal, null) + + case IntegerType => + (left, right, ordinal) => + left.getInt(ordinal) === right.getInt(ordinal) + + case LongType => + (left, right, ordinal) => + left.getLong(ordinal) === right.getLong(ordinal) + + case FloatType => + (left, right, ordinal) => + left.getFloat(ordinal) === (right.getFloat(ordinal) +- epsilon) + + case DoubleType => + (left, right, ordinal) => + left.getDouble(ordinal) === (right.getDouble(ordinal) +- epsilon) + + case DecimalType() => + (left, right, ordinal) => + left + .getDecimal( + ordinal, + DecimalType.USER_DEFAULT.precision, + DecimalType.USER_DEFAULT.scale + ) + .toDouble === + (right + .getDecimal( + ordinal, + DecimalType.USER_DEFAULT.precision, + DecimalType.USER_DEFAULT.scale + ) + .toDouble + +- epsilon) + + case StringType => + (left, right, ordinal) => + left.getUTF8String(ordinal).getBytes === right + .getUTF8String(ordinal) + .getBytes + + case BinaryType => + (left, right, ordinal) => + left.getBinary(ordinal) === right.getBinary(ordinal) + + case ArrayType(elementType, _) => + (left, right, ordinal) => + if (left.get(ordinal, null) == null || right.get(ordinal, null) == null) { + left.get(ordinal, null) == right.get(ordinal, null) + } else { + val leftArray = left.getArray(ordinal) + val rightArray = right.getArray(ordinal) + if (leftArray.numElements == rightArray.numElements) { + val len = leftArray.numElements() + val 
elementValueCompare = newValueCompare(elementType) + var result = true + var idx: Int = 0 + while (idx < len && result) { + result = elementValueCompare(leftArray, rightArray, idx) + idx += 1 + } + result + } else false + } + case _ => + throw new RuntimeException( + s"Cannot convert field to unsupported data type ${dataType}" + ) + } /** * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. */ - def ~==(right: InternalRow, schema: StructType, epsilon : Float = 1E-6F): Boolean = { + def ~==(right: InternalRow, + schema: StructType, + epsilon: Float = 1E-6F): Boolean = { if (schema != null && schema.fields.size == left.numFields && schema.fields.size == right.numFields) { schema.fields.map(_.dataType).zipWithIndex.forall { case (dataType, idx) => - val valueCompare = newValueCompare(dataType) - valueCompare(left, right, idx) + val valueCompare = newValueCompare(dataType) + valueCompare(left, right, idx) } - } - else false + } else false } } /** - * Implicit class for comparing two rows using absolute tolerance. - */ + * Implicit class for comparing two rows using absolute tolerance. + */ implicit class RowWithAlmostEquals(val left: Row) { /** - * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. - */ + * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. 
+ */ def ~==(right: Row, schema: StructType): Boolean = { if (schema != null && schema.fields.size == left.size && schema.fields.size == right.size) { - val leftRowWithSchema = new GenericRowWithSchema(left.toSeq.toArray, schema) - val rightRowWithSchema = new GenericRowWithSchema(right.toSeq.toArray, schema) + val leftRowWithSchema = + new GenericRowWithSchema(left.toSeq.toArray, schema) + val rightRowWithSchema = + new GenericRowWithSchema(right.toSeq.toArray, schema) leftRowWithSchema ~== rightRowWithSchema - } - else false + } else false } /** - * When all fields in row are equal or are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Row, epsilon : Float = 1E-6F): Boolean = { + * When all fields in row are equal or are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Row, epsilon: Float = 1E-6F): Boolean = { if (left.size === right.size) { val leftDataTypes = left.schema.fields.map(_.dataType) val rightDataTypes = right.schema.fields.map(_.dataType) @@ -227,54 +256,99 @@ object TestingUtils extends Matchers { left.getDouble(i) === (right.getDouble(i) +- epsilon) case ((BinaryType, BinaryType), i) => - left.getAs[Array[Byte]](i).toSeq === right.getAs[Array[Byte]](i).toSeq - - case ((ArrayType(FloatType,_), ArrayType(FloatType,_)), i) => - val leftArray = ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq - val rightArray = ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq + left.getAs[Array[Byte]](i).toSeq === right + .getAs[Array[Byte]](i) + .toSeq + + case ((ArrayType(FloatType, _), ArrayType(FloatType, _)), i) => + val leftArray = + ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq + val rightArray = + ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq leftArray ~== (rightArray, epsilon) - case ((ArrayType(DoubleType,_), ArrayType(DoubleType,_)), i) => - val leftArray = ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq - val rightArray = 
ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq + case ((ArrayType(DoubleType, _), ArrayType(DoubleType, _)), i) => + val leftArray = + ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq + val rightArray = + ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq leftArray ~== (rightArray, epsilon) - case ((ArrayType(BinaryType,_), ArrayType(BinaryType,_)), i) => - val leftArray = ArrayData.toArrayData(left.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq - val rightArray = ArrayData.toArrayData(right.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq + case ((ArrayType(BinaryType, _), ArrayType(BinaryType, _)), i) => + val leftArray = ArrayData + .toArrayData(left.get(i)) + .toArray[Array[Byte]](BinaryType) + .map(_.toSeq) + .toSeq + val rightArray = ArrayData + .toArrayData(right.get(i)) + .toArray[Array[Byte]](BinaryType) + .map(_.toSeq) + .toSeq leftArray === rightArray - case ((ArrayType(ArrayType(FloatType,_),_), ArrayType(ArrayType(FloatType,_),_)), i) => - val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => - ArrayData.toArrayData(arr).toFloatArray().toSeq - } - val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => - ArrayData.toArrayData(arr).toFloatArray().toSeq - } + case ( + ( + ArrayType(ArrayType(FloatType, _), _), + ArrayType(ArrayType(FloatType, _), _) + ), + i + ) => + val leftArrays = + ArrayData.toArrayData(left.get(i)).array.toSeq.map { arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } + val rightArrays = + ArrayData.toArrayData(right.get(i)).array.toSeq.map { arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } leftArrays ~== (rightArrays, epsilon) - case ((ArrayType(ArrayType(DoubleType,_),_), ArrayType(ArrayType(DoubleType,_),_)), i) => - val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => - ArrayData.toArrayData(arr).toDoubleArray().toSeq - } - val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map 
{arr => - ArrayData.toArrayData(arr).toDoubleArray().toSeq - } + case ( + ( + ArrayType(ArrayType(DoubleType, _), _), + ArrayType(ArrayType(DoubleType, _), _) + ), + i + ) => + val leftArrays = + ArrayData.toArrayData(left.get(i)).array.toSeq.map { arr => + ArrayData.toArrayData(arr).toDoubleArray().toSeq + } + val rightArrays = + ArrayData.toArrayData(right.get(i)).array.toSeq.map { arr => + ArrayData.toArrayData(arr).toDoubleArray().toSeq + } leftArrays ~== (rightArrays, epsilon) - case ((ArrayType(ArrayType(BinaryType,_),_), ArrayType(ArrayType(BinaryType,_),_)), i) => - val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => - ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq - } - val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => - ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq - } + case ( + ( + ArrayType(ArrayType(BinaryType, _), _), + ArrayType(ArrayType(BinaryType, _), _) + ), + i + ) => + val leftArrays = + ArrayData.toArrayData(left.get(i)).array.toSeq.map { arr => + ArrayData + .toArrayData(arr) + .toArray[Array[Byte]](BinaryType) + .map(_.toSeq) + .toSeq + } + val rightArrays = + ArrayData.toArrayData(right.get(i)).array.toSeq.map { arr => + ArrayData + .toArrayData(arr) + .toArray[Array[Byte]](BinaryType) + .map(_.toSeq) + .toSeq + } leftArrays === rightArrays - case((a,b), i) => left.get(i) === right.get(i) + case ((a, b), i) => left.get(i) === right.get(i) } - } - else false + } else false } } -} \ No newline at end of file +} From 4548ba66aab0ee97b5588dca2136f5e80ffc663d Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Thu, 27 Aug 2020 13:55:03 +1000 Subject: [PATCH 2/7] 2.11 passing. 
--- pom.xml | 2 +- .../spark/datasources/tfrecord/SharedSparkSessionSuite.scala | 5 +++-- .../datasources/tfrecord/TFRecordDeserializerTest.scala | 5 +++-- .../spark/datasources/tfrecord/TFRecordSerializerTest.scala | 5 +++-- .../linkedin/spark/datasources/tfrecord/TestingUtils.scala | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/pom.xml b/pom.xml index ce6fb09..b911b1e 100644 --- a/pom.xml +++ b/pom.xml @@ -342,7 +342,7 @@ scala-2.11 2.11 - 2.11.8 + 2.11.12 diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala index 7f12ef5..720b9d6 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala @@ -21,9 +21,10 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SharedSparkSession import org.junit.{After, Before} import org.scalatest._ -import matchers._ +import matchers.should._ +import org.scalatest.wordspec.AnyWordSpecLike -trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll +trait BaseSuite extends AnyWordSpecLike with Matchers with BeforeAndAfterAll class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { val TF_SANDBOX_DIR = "tf-sandbox" diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala index 51250cf..b59b622 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala @@ -23,9 +23,10 @@ import org.apache.spark.unsafe.types.UTF8String import org.tensorflow.example._ import TestingUtils._ import org.scalatest._ -import matchers._ +import matchers.should._ +import 
org.scalatest.wordspec.AnyWordSpecLike -class TFRecordDeserializerTest extends WordSpec with Matchers { +class TFRecordDeserializerTest extends AnyWordSpecLike with Matchers { val intFeature = Feature .newBuilder() .setInt64List(Int64List.newBuilder().addValue(1)) diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala index 47b9e00..0007312 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala @@ -21,12 +21,13 @@ import org.apache.spark.sql.types.{StructField, _} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.unsafe.types.UTF8String import org.scalatest._ -import matchers._ +import matchers.should._ +import org.scalatest.wordspec.AnyWordSpecLike import scala.collection.JavaConverters._ import TestingUtils._ -class TFRecordSerializerTest extends WordSpec with Matchers { +class TFRecordSerializerTest extends AnyWordSpecLike with Matchers { private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala index 5d9380a..bdb4fb7 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.scalatest._ -import matchers._ +import matchers.should._ object TestingUtils extends Matchers { From bc22671aa42aaf79f0adbb515b77e03b9883e5f4 Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Thu, 27 Aug 2020 14:15:30 +1000 Subject: 
[PATCH 3/7] Travis matrix build. --- .travis.yml | 11 ++++++++--- pom.xml | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index b802e22..408f1ee 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,12 +1,17 @@ dist: trusty language: scala -scala: - - 2.11.8 git: depth: 3 jdk: - oraclejdk8 -script: "mvn test -B" + +matrix: + include: + - scala: 2.11.12 + script: "mvn test -B -Pscala-2.11" + + - scala: 2.12.12 + script: "mvn test -B -Pscala-2.12" # safelist branches: diff --git a/pom.xml b/pom.xml index b911b1e..8ab2892 100644 --- a/pom.xml +++ b/pom.xml @@ -350,7 +350,7 @@ scala-2.12 2.12 - 2.12.10 + 2.12.12 From aceb08c32e1aef1467d7a3ff1d6bb6b9f9d7d100 Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Thu, 27 Aug 2020 14:48:32 +1000 Subject: [PATCH 4/7] Spark 2.4 + Scala 2.11, Spark 3.0 + Scala 2.12. --- pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 8ab2892..c65c2fa 100644 --- a/pom.xml +++ b/pom.xml @@ -32,7 +32,6 @@ 3.2.2 3.0 1.8 - 2.4.6 4.11 1.15.0 @@ -343,6 +342,7 @@ 2.11 2.11.12 + 2.4.6 @@ -351,6 +351,7 @@ 2.12 2.12.12 + 3.0.0 From a41e62c1dc959dee6dee54616cd1d6c87251c481 Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Thu, 27 Aug 2020 22:54:44 +1000 Subject: [PATCH 5/7] Back to cross-compilation. 
--- pom.xml | 5 ++--- .../datasources/tfrecord/SharedSparkSessionSuite.scala | 6 ++---- .../datasources/tfrecord/TFRecordDeserializerTest.scala | 6 ++---- .../spark/datasources/tfrecord/TFRecordSerializerTest.scala | 6 ++---- .../linkedin/spark/datasources/tfrecord/TestingUtils.scala | 3 +-- 5 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pom.xml b/pom.xml index c65c2fa..c0b8298 100644 --- a/pom.xml +++ b/pom.xml @@ -29,7 +29,8 @@ UTF-8 3.2.2 1.0 - 3.2.2 + 3.0.8 + 2.4.6 3.0 1.8 4.11 @@ -342,7 +343,6 @@ 2.11 2.11.12 - 2.4.6 @@ -351,7 +351,6 @@ 2.12 2.12.12 - 3.0.0 diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala index 720b9d6..34e574b 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala @@ -20,11 +20,9 @@ import java.io.File import org.apache.commons.io.FileUtils import org.apache.spark.SharedSparkSession import org.junit.{After, Before} -import org.scalatest._ -import matchers.should._ -import org.scalatest.wordspec.AnyWordSpecLike +import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} -trait BaseSuite extends AnyWordSpecLike with Matchers with BeforeAndAfterAll +trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { val TF_SANDBOX_DIR = "tf-sandbox" diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala index b59b622..f2a4222 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala @@ -22,11 +22,9 @@ import 
org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.tensorflow.example._ import TestingUtils._ -import org.scalatest._ -import matchers.should._ -import org.scalatest.wordspec.AnyWordSpecLike +import org.scalatest.{Matchers, WordSpecLike} -class TFRecordDeserializerTest extends AnyWordSpecLike with Matchers { +class TFRecordDeserializerTest extends WordSpecLike with Matchers { val intFeature = Feature .newBuilder() .setInt64List(Int64List.newBuilder().addValue(1)) diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala index 0007312..65e8874 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala @@ -20,14 +20,12 @@ import org.tensorflow.example._ import org.apache.spark.sql.types.{StructField, _} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.unsafe.types.UTF8String -import org.scalatest._ -import matchers.should._ -import org.scalatest.wordspec.AnyWordSpecLike import scala.collection.JavaConverters._ import TestingUtils._ +import org.scalatest.{Matchers, WordSpecLike} -class TFRecordSerializerTest extends AnyWordSpecLike with Matchers { +class TFRecordSerializerTest extends WordSpecLike with Matchers { private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala index bdb4fb7..d9e2e06 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala @@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{ } import 
org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ -import org.scalatest._ -import matchers.should._ +import org.scalatest.Matchers object TestingUtils extends Matchers { From bf675a081a74f6275d621061a4ca6979da2226f2 Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Fri, 28 Aug 2020 07:59:30 +1000 Subject: [PATCH 6/7] Revert formatting back to original --- .../tfrecord/SharedSparkSessionSuite.scala | 30 +- .../tfrecord/TFRecordDeserializerTest.scala | 366 +++++--------- .../tfrecord/TFRecordSerializerTest.scala | 453 ++++++------------ .../datasources/tfrecord/TestingUtils.scala | 393 +++++++-------- 4 files changed, 435 insertions(+), 807 deletions(-) diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala index 34e574b..72b2509 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/SharedSparkSessionSuite.scala @@ -1,18 +1,18 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.linkedin.spark.datasources.tfrecord import java.io.File @@ -22,6 +22,7 @@ import org.apache.spark.SharedSparkSession import org.junit.{After, Before} import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike} + trait BaseSuite extends WordSpecLike with Matchers with BeforeAndAfterAll class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { @@ -41,3 +42,4 @@ class SharedSparkSessionSuite extends SharedSparkSession with BaseSuite { super.tearDown() } } + diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala index f2a4222..6899b2d 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala @@ -1,18 +1,18 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.linkedin.spark.datasources.tfrecord import com.google.protobuf.ByteString @@ -20,113 +20,50 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.scalatest.{Matchers, WordSpec} import org.tensorflow.example._ import TestingUtils._ -import org.scalatest.{Matchers, WordSpecLike} - -class TFRecordDeserializerTest extends WordSpecLike with Matchers { - val intFeature = Feature - .newBuilder() - .setInt64List(Int64List.newBuilder().addValue(1)) - .build() - val longFeature = Feature - .newBuilder() - .setInt64List(Int64List.newBuilder().addValue(23L)) - .build() - val floatFeature = Feature - .newBuilder() - .setFloatList(FloatList.newBuilder().addValue(10.0F)) - .build() - val doubleFeature = Feature - .newBuilder() - .setFloatList(FloatList.newBuilder().addValue(14.0F)) - .build() - val decimalFeature = Feature - .newBuilder() - .setFloatList(FloatList.newBuilder().addValue(2.5F)) - .build() - val longArrFeature = Feature - .newBuilder() - .setInt64List(Int64List.newBuilder().addValue(-2L).addValue(7L).build()) - .build() - val doubleArrFeature = Feature - .newBuilder() - .setFloatList(FloatList.newBuilder().addValue(1F).addValue(2F).build()) - .build() - val 
decimalArrFeature = Feature - .newBuilder() - .setFloatList(FloatList.newBuilder().addValue(3F).addValue(5F).build()) - .build() - val strFeature = Feature - .newBuilder() - .setBytesList( - BytesList - .newBuilder() - .addValue(ByteString.copyFrom("r1".getBytes)) - .build() - ) - .build() - val strListFeature = Feature - .newBuilder() - .setBytesList( - BytesList - .newBuilder() - .addValue(ByteString.copyFrom("r2".getBytes)) - .addValue(ByteString.copyFrom("r3".getBytes)) - .build() - ) - .build() - val binaryFeature = Feature - .newBuilder() - .setBytesList( - BytesList.newBuilder().addValue(ByteString.copyFrom("r4".getBytes)) - ) - .build() - val binaryListFeature = Feature - .newBuilder() - .setBytesList( - BytesList - .newBuilder() - .addValue(ByteString.copyFrom("r5".getBytes)) - .addValue(ByteString.copyFrom("r6".getBytes)) - .build() - ) - .build() - private def createArray(values: Any*): ArrayData = - new GenericArrayData(values.toArray) + +class TFRecordDeserializerTest extends WordSpec with Matchers { + val intFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(1)).build() + val longFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(23L)).build() + val floatFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(10.0F)).build() + val doubleFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(14.0F)).build() + val decimalFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(2.5F)).build() + val longArrFeature = Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(-2L).addValue(7L).build()).build() + val doubleArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(1F).addValue(2F).build()).build() + val decimalArrFeature = Feature.newBuilder().setFloatList(FloatList.newBuilder().addValue(3F).addValue(5F).build()).build() + val strFeature = 
Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r1".getBytes)).build()).build() + val strListFeature =Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r2".getBytes)) + .addValue(ByteString.copyFrom("r3".getBytes)).build()).build() + val binaryFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r4".getBytes))).build() + val binaryListFeature = Feature.newBuilder().setBytesList(BytesList.newBuilder().addValue(ByteString.copyFrom("r5".getBytes)) + .addValue(ByteString.copyFrom("r6".getBytes)).build()).build() + + private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) "Deserialize tfrecord to spark internalRow" should { "Serialize tfrecord example to spark internalRow" in { - val schema = StructType( - List( - StructField("IntegerLabel", IntegerType), - StructField("LongLabel", LongType), - StructField("FloatLabel", FloatType), - StructField("DoubleLabel", DoubleType), - StructField("DecimalLabel", DataTypes.createDecimalType()), - StructField("LongArrayLabel", ArrayType(LongType)), - StructField("DoubleArrayLabel", ArrayType(DoubleType)), - StructField( - "DecimalArrayLabel", - ArrayType(DataTypes.createDecimalType()) - ), - StructField("StrLabel", StringType), - StructField("StrArrayLabel", ArrayType(StringType)), - StructField("BinaryTypeLabel", BinaryType), - StructField("BinaryTypeArrayLabel", ArrayType(BinaryType)) - ) - ) + val schema = StructType(List( + StructField("IntegerLabel", IntegerType), + StructField("LongLabel", LongType), + StructField("FloatLabel", FloatType), + StructField("DoubleLabel", DoubleType), + StructField("DecimalLabel", DataTypes.createDecimalType()), + StructField("LongArrayLabel", ArrayType(LongType)), + StructField("DoubleArrayLabel", ArrayType(DoubleType)), + StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), + StructField("StrLabel", StringType), + 
StructField("StrArrayLabel", ArrayType(StringType)), + StructField("BinaryTypeLabel", BinaryType), + StructField("BinaryTypeArrayLabel", ArrayType(BinaryType)) + )) val expectedInternalRow = InternalRow.fromSeq( - Array[Any]( - 1, - 23L, - 10.0F, - 14.0, - Decimal(2.5d), - createArray(-2L, 7L), + Array[Any](1, 23L, 10.0F, 14.0, Decimal(2.5d), + createArray(-2L,7L), createArray(1.0, 2.0), createArray(Decimal(3.0), Decimal(5.0)), UTF8String.fromString("r1"), @@ -137,8 +74,7 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { ) //Build example - val features = Features - .newBuilder() + val features = Features.newBuilder() .putFeature("IntegerLabel", intFeature) .putFeature("LongLabel", longFeature) .putFeature("FloatLabel", floatFeature) @@ -152,82 +88,49 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { .putFeature("BinaryTypeLabel", binaryFeature) .putFeature("BinaryTypeArrayLabel", binaryListFeature) .build() - val example = Example - .newBuilder() + val example = Example.newBuilder() .setFeatures(features) .build() val deserializer = new TFRecordDeserializer(schema) val actualInternalRow = deserializer.deserializeExample(example) - assert(actualInternalRow ~== (expectedInternalRow, schema)) + assert(actualInternalRow ~== (expectedInternalRow,schema)) } "Serialize spark internalRow to tfrecord sequenceExample" in { - val schema = StructType( - List( - StructField("FloatLabel", FloatType), - StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), - StructField( - "FloatArrayOfArrayLabel", - ArrayType(ArrayType(FloatType)) - ), - StructField( - "DecimalArrayOfArrayLabel", - ArrayType(ArrayType(DataTypes.createDecimalType())) - ), - StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))), - StructField("ByteArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) - ) - ) + val schema = StructType(List( + StructField("FloatLabel", FloatType), + StructField("LongArrayOfArrayLabel", 
ArrayType(ArrayType(LongType))), + StructField("FloatArrayOfArrayLabel", ArrayType(ArrayType(FloatType))), + StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))), + StructField("StrArrayOfArrayLabel", ArrayType(ArrayType(StringType))), + StructField("ByteArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) + )) val expectedInternalRow = InternalRow.fromSeq( - Array[Any]( - 10.0F, + Array[Any](10.0F, createArray(createArray(-2L, 7L)), createArray(createArray(10.0F), createArray(1.0F, 2.0F)), createArray(createArray(Decimal(3.0), Decimal(5.0))), - createArray( - createArray( - UTF8String.fromString("r2"), - UTF8String.fromString("r3") - ), - createArray(UTF8String.fromString("r1")) - ), - createArray( - createArray("r5".getBytes, "r6".getBytes), - createArray("r4".getBytes) - ) + createArray(createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), + createArray(UTF8String.fromString("r1"))), + createArray(createArray("r5".getBytes, "r6".getBytes), createArray("r4".getBytes)) ) ) //Build sequence example - val int64FeatureList = - FeatureList.newBuilder().addFeature(longArrFeature).build() - val floatFeatureList = FeatureList - .newBuilder() - .addFeature(floatFeature) - .addFeature(doubleArrFeature) - .build() - val decimalFeatureList = - FeatureList.newBuilder().addFeature(decimalArrFeature).build() - val stringFeatureList = FeatureList - .newBuilder() - .addFeature(strListFeature) - .addFeature(strFeature) - .build() - val binaryFeatureList = FeatureList - .newBuilder() - .addFeature(binaryListFeature) - .addFeature(binaryFeature) - .build() + val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() + val floatFeatureList = FeatureList.newBuilder().addFeature(floatFeature).addFeature(doubleArrFeature).build() + val decimalFeatureList = FeatureList.newBuilder().addFeature(decimalArrFeature).build() + val stringFeatureList = 
FeatureList.newBuilder().addFeature(strListFeature).addFeature(strFeature).build() + val binaryFeatureList = FeatureList.newBuilder().addFeature(binaryListFeature).addFeature(binaryFeature).build() - val features = Features - .newBuilder() + + val features = Features.newBuilder() .putFeature("FloatLabel", floatFeature) - val featureLists = FeatureLists - .newBuilder() + val featureLists = FeatureLists.newBuilder() .putFeatureList("LongArrayOfArrayLabel", int64FeatureList) .putFeatureList("FloatArrayOfArrayLabel", floatFeatureList) .putFeatureList("DecimalArrayOfArrayLabel", decimalFeatureList) @@ -235,29 +138,24 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { .putFeatureList("ByteArrayOfArrayLabel", binaryFeatureList) .build() - val seqExample = SequenceExample - .newBuilder() + val seqExample = SequenceExample.newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() val deserializer = new TFRecordDeserializer(schema) - val actualInternalRow = - deserializer.deserializeSequenceExample(seqExample) + val actualInternalRow = deserializer.deserializeSequenceExample(seqExample) assert(actualInternalRow ~== (expectedInternalRow, schema)) } "Throw an exception for unsupported data types" in { val features = Features.newBuilder().putFeature("MapLabel1", intFeature) - val int64FeatureList = - FeatureList.newBuilder().addFeature(longArrFeature).build() - val featureLists = - FeatureLists.newBuilder().putFeatureList("MapLabel2", int64FeatureList) + val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() + val featureLists = FeatureLists.newBuilder().putFeatureList("MapLabel2", int64FeatureList) intercept[RuntimeException] { - val example = Example - .newBuilder() + val example = Example.newBuilder() .setFeatures(features) .build() val schema = StructType(List(StructField("MapLabel1", TimestampType))) @@ -266,8 +164,7 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { } 
intercept[RuntimeException] { - val seqExample = SequenceExample - .newBuilder() + val seqExample = SequenceExample.newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() @@ -278,65 +175,45 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { } "Throw an exception for non-nullable data types" in { - val features = - Features.newBuilder().putFeature("FloatLabel", floatFeature) - val int64FeatureList = - FeatureList.newBuilder().addFeature(longArrFeature).build() - val featureLists = FeatureLists - .newBuilder() - .putFeatureList("LongArrayOfArrayLabel", int64FeatureList) + val features = Features.newBuilder().putFeature("FloatLabel", floatFeature) + val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() + val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList) intercept[NullPointerException] { - val example = Example - .newBuilder() + val example = Example.newBuilder() .setFeatures(features) .build() - val schema = StructType( - List(StructField("MissingLabel", FloatType, nullable = false)) - ) + val schema = StructType(List(StructField("MissingLabel", FloatType, nullable = false))) val deserializer = new TFRecordDeserializer(schema) deserializer.deserializeExample(example) } intercept[NullPointerException] { - val seqExample = SequenceExample - .newBuilder() + val seqExample = SequenceExample.newBuilder() .setContext(features) .setFeatureLists(featureLists) .build() - val schema = StructType( - List( - StructField( - "MissingLabel", - ArrayType(ArrayType(LongType)), - nullable = false - ) - ) - ) + val schema = StructType(List(StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = false))) val deserializer = new TFRecordDeserializer(schema) deserializer.deserializeSequenceExample(seqExample) } } + "Return null fields for nullable data types" in { - val features = - Features.newBuilder().putFeature("FloatLabel", floatFeature) - val 
int64FeatureList = - FeatureList.newBuilder().addFeature(longArrFeature).build() - val featureLists = FeatureLists - .newBuilder() - .putFeatureList("LongArrayOfArrayLabel", int64FeatureList) + val features = Features.newBuilder().putFeature("FloatLabel", floatFeature) + val int64FeatureList = FeatureList.newBuilder().addFeature(longArrFeature).build() + val featureLists = FeatureLists.newBuilder().putFeatureList("LongArrayOfArrayLabel", int64FeatureList) // Deserialize Example - val schema1 = StructType( - List( - StructField("FloatLabel", FloatType), - StructField("MissingLabel", FloatType, nullable = true) - ) + val schema1 = StructType(List( + StructField("FloatLabel", FloatType), + StructField("MissingLabel", FloatType, nullable = true)) + ) + val expectedInternalRow1 = InternalRow.fromSeq( + Array[Any](10.0F, null) ) - val expectedInternalRow1 = InternalRow.fromSeq(Array[Any](10.0F, null)) - val example = Example - .newBuilder() + val example = Example.newBuilder() .setFeatures(features) .build() val deserializer1 = new TFRecordDeserializer(schema1) @@ -344,31 +221,27 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { assert(actualInternalRow1 ~== (expectedInternalRow1, schema1)) // Deserialize SequenceExample - val schema2 = StructType( - List( - StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), - StructField( - "MissingLabel", - ArrayType(ArrayType(LongType)), - nullable = true - ) - ) + val schema2 = StructType(List( + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField("MissingLabel", ArrayType(ArrayType(LongType)), nullable = true)) + ) + val expectedInternalRow2 = InternalRow.fromSeq( + Array[Any]( + createArray(createArray(-2L, 7L)), null) ) - val expectedInternalRow2 = - InternalRow.fromSeq(Array[Any](createArray(createArray(-2L, 7L)), null)) - val seqExample = SequenceExample - .newBuilder() + val seqExample = SequenceExample.newBuilder() .setContext(features) 
.setFeatureLists(featureLists) .build() val deserializer2 = new TFRecordDeserializer(schema2) - val actualInternalRow2 = - deserializer2.deserializeSequenceExample(seqExample) + val actualInternalRow2 = deserializer2.deserializeSequenceExample(seqExample) assert(actualInternalRow2 ~== (expectedInternalRow2, schema2)) } - val schema = StructType(Array(StructField("LongLabel", LongType))) + val schema = StructType(Array( + StructField("LongLabel", LongType)) + ) val deserializer = new TFRecordDeserializer(schema) "Test Int64ListFeature2SeqLong" in { @@ -398,16 +271,9 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { } "Test bytesListFeature2SeqArrayByte" in { - val bytesList = BytesList - .newBuilder() - .addValue(ByteString.copyFrom("str-input".getBytes)) - .build() + val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build() val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() - assert( - deserializer - .bytesListFeature2SeqArrayByte(bytesFeature) - .head === "str-input".getBytes.deep - ) + assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head === "str-input".getBytes.deep) // Throw exception if type doesn't match intercept[RuntimeException] { @@ -418,16 +284,10 @@ class TFRecordDeserializerTest extends WordSpecLike with Matchers { } "Test bytesListFeature2SeqString" in { - val bytesList = BytesList - .newBuilder() - .addValue(ByteString.copyFrom("alice".getBytes)) - .addValue(ByteString.copyFrom("bob".getBytes)) - .build() + val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("alice".getBytes)) + .addValue(ByteString.copyFrom("bob".getBytes)).build() val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() - assert( - deserializer - .bytesListFeature2SeqString(bytesFeature) === Seq("alice", "bob") - ) + assert(deserializer.bytesListFeature2SeqString(bytesFeature) === Seq("alice", "bob")) // Throw exception if type doesn't match 
intercept[RuntimeException] { diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala index 65e8874..d1671a5 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala @@ -1,18 +1,18 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.linkedin.spark.datasources.tfrecord import org.apache.spark.sql.catalyst.InternalRow @@ -20,33 +20,26 @@ import org.tensorflow.example._ import org.apache.spark.sql.types.{StructField, _} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.unsafe.types.UTF8String +import org.scalatest.{Matchers, WordSpec} import scala.collection.JavaConverters._ import TestingUtils._ -import org.scalatest.{Matchers, WordSpecLike} -class TFRecordSerializerTest extends WordSpecLike with Matchers { +class TFRecordSerializerTest extends WordSpec with Matchers { - private def createArray(values: Any*): ArrayData = - new GenericArrayData(values.toArray) + private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) "Serialize spark internalRow to tfrecord" should { "Serialize decimal internalRow to tfrecord example" in { - val schemaStructType = StructType( - Array( - StructField("DecimalLabel", DataTypes.createDecimalType()), - StructField( - "DecimalArrayLabel", - ArrayType(DataTypes.createDecimalType()) - ) - ) - ) + val schemaStructType = StructType(Array( + StructField("DecimalLabel", DataTypes.createDecimalType()), + StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())) + )) val serializer = new TFRecordSerializer(schemaStructType) val decimalArray = Array(Decimal(4.0), Decimal(8.0)) - val rowArray = - Array[Any](Decimal(6.5), ArrayData.toArrayData(decimalArray)) + val rowArray = Array[Any](Decimal(6.5), ArrayData.toArrayData(decimalArray)) val internalRow = InternalRow.fromSeq(rowArray) //Encode Sql InternalRow to TensorFlow example @@ -56,57 +49,37 @@ class TFRecordSerializerTest extends WordSpecLike with Matchers { val featureMap = example.getFeatures.getFeatureMap.asScala assert(featureMap.size == rowArray.length) - assert( - featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) + assert(featureMap("DecimalLabel").getKindCase.getNumber 
== Feature.FLOAT_LIST_FIELD_NUMBER) assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F) - assert( - featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) - assert( - featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq - .map(_.toFloat) ~== decimalArray.map(_.toFloat) - ) + assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat)) } "Serialize complex internalRow to tfrecord example" in { - val schemaStructType = StructType( - Array( - StructField("IntegerLabel", IntegerType), - StructField("LongLabel", LongType), - StructField("FloatLabel", FloatType), - StructField("DoubleLabel", DoubleType), - StructField("DecimalLabel", DataTypes.createDecimalType()), - StructField("DoubleArrayLabel", ArrayType(DoubleType)), - StructField( - "DecimalArrayLabel", - ArrayType(DataTypes.createDecimalType()) - ), - StructField("StrLabel", StringType), - StructField("StrArrayLabel", ArrayType(StringType)), - StructField("BinaryLabel", BinaryType), - StructField("BinaryArrayLabel", ArrayType(BinaryType)) - ) - ) + val schemaStructType = StructType(Array( + StructField("IntegerLabel", IntegerType), + StructField("LongLabel", LongType), + StructField("FloatLabel", FloatType), + StructField("DoubleLabel", DoubleType), + StructField("DecimalLabel", DataTypes.createDecimalType()), + StructField("DoubleArrayLabel", ArrayType(DoubleType)), + StructField("DecimalArrayLabel", ArrayType(DataTypes.createDecimalType())), + StructField("StrLabel", StringType), + StructField("StrArrayLabel", ArrayType(StringType)), + StructField("BinaryLabel", BinaryType), + StructField("BinaryArrayLabel", ArrayType(BinaryType)) + )) val doubleArray = Array(1.1, 111.1, 11111.1) val decimalArray = Array(Decimal(4.0), Decimal(8.0)) - val byteArray = - Array[Byte](0xde.toByte, 
0xad.toByte, 0xbe.toByte, 0xef.toByte) + val byteArray = Array[Byte](0xde.toByte, 0xad.toByte, 0xbe.toByte, 0xef.toByte) val byteArray1 = Array[Byte](-128, 23, 127) - val rowArray = Array[Any]( - 1, - 23L, - 10.0F, - 14.0, - Decimal(6.5), + val rowArray = Array[Any](1, 23L, 10.0F, 14.0, Decimal(6.5), ArrayData.toArrayData(doubleArray), ArrayData.toArrayData(decimalArray), UTF8String.fromString("r1"), - ArrayData.toArrayData( - Array(UTF8String.fromString("r2"), UTF8String.fromString("r3")) - ), + ArrayData.toArrayData(Array(UTF8String.fromString("r2"), UTF8String.fromString("r3"))), byteArray, ArrayData.toArrayData(Array(byteArray, byteArray1)) ) @@ -120,139 +93,78 @@ class TFRecordSerializerTest extends WordSpecLike with Matchers { val featureMap = example.getFeatures.getFeatureMap.asScala assert(featureMap.size == rowArray.length) - assert( - featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER - ) + assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 1) - assert( - featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER - ) + assert(featureMap("LongLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) assert(featureMap("LongLabel").getInt64List.getValue(0).toInt == 23) - assert( - featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) + assert(featureMap("FloatLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) assert(featureMap("FloatLabel").getFloatList.getValue(0) == 10.0F) - assert( - featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) + assert(featureMap("DoubleLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) assert(featureMap("DoubleLabel").getFloatList.getValue(0) == 14.0F) - assert( - featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) + 
assert(featureMap("DecimalLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) assert(featureMap("DecimalLabel").getFloatList.getValue(0) == 6.5F) - assert( - featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) - assert( - featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq - .map(_.toFloat) ~== doubleArray.map(_.toFloat) - ) + assert(featureMap("DoubleArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("DoubleArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== doubleArray.map(_.toFloat)) - assert( - featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) - assert( - featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq - .map(_.toFloat) ~== decimalArray.map(_.toFloat) - ) + assert(featureMap("DecimalArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("DecimalArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== decimalArray.map(_.toFloat)) - assert( - featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER - ) - assert( - featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1" - ) + assert(featureMap("StrLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("StrLabel").getBytesList.getValue(0).toStringUtf8 == "r1") - assert( - featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER - ) - assert( - featureMap("StrArrayLabel").getBytesList.getValueList.asScala - .map(_.toStringUtf8) === Seq("r2", "r3") - ) + assert(featureMap("StrArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3")) - assert( - featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER - ) - assert( - 
featureMap("BinaryLabel").getBytesList - .getValue(0) - .toByteArray - .deep == byteArray.deep - ) + assert(featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.deep == byteArray.deep) - assert( - featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER - ) - val binaryArrayValue = - featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala - .map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) + assert(featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) assert(binaryArrayValue.toArray.deep == Array(byteArray, byteArray1).deep) } "Serialize internalRow to tfrecord sequenceExample" in { - val schemaStructType = StructType( - Array( - StructField("IntegerLabel", IntegerType), - StructField("StringArrayLabel", ArrayType(StringType)), - StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), - StructField( - "FloatArrayOfArrayLabel", - ArrayType(ArrayType(FloatType)) - ), - StructField( - "DoubleArrayOfArrayLabel", - ArrayType(ArrayType(DoubleType)) - ), - StructField( - "DecimalArrayOfArrayLabel", - ArrayType(ArrayType(DataTypes.createDecimalType())) - ), - StructField( - "StringArrayOfArrayLabel", - ArrayType(ArrayType(StringType)) - ), - StructField( - "BinaryArrayOfArrayLabel", - ArrayType(ArrayType(BinaryType)) - ) - ) - ) - - val stringList = Array( - UTF8String.fromString("r1"), - UTF8String.fromString("r2"), - UTF8String.fromString(("r3")) - ) + val schemaStructType = StructType(Array( + StructField("IntegerLabel", IntegerType), + StructField("StringArrayLabel", ArrayType(StringType)), + StructField("LongArrayOfArrayLabel", ArrayType(ArrayType(LongType))), + StructField("FloatArrayOfArrayLabel", 
ArrayType(ArrayType(FloatType))) , + StructField("DoubleArrayOfArrayLabel", ArrayType(ArrayType(DoubleType))), + StructField("DecimalArrayOfArrayLabel", ArrayType(ArrayType(DataTypes.createDecimalType()))), + StructField("StringArrayOfArrayLabel", ArrayType(ArrayType(StringType))), + StructField("BinaryArrayOfArrayLabel", ArrayType(ArrayType(BinaryType))) + )) + + val stringList = Array(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))) val longListOfLists = Array(Array(3L, 5L), Array(-8L, 0L)) val floatListOfLists = Array(Array(1.5F, -6.5F), Array(-8.2F, 0F)) val doubleListOfLists = Array(Array(3.0), Array(6.0, 9.0)) - val decimalListOfLists = - Array(Array(Decimal(2.0), Decimal(4.0)), Array(Decimal(6.0))) - val stringListOfLists = Array( - Array(UTF8String.fromString("r1")), + val decimalListOfLists = Array(Array(Decimal(2.0), Decimal(4.0)), Array(Decimal(6.0))) + val stringListOfLists = Array(Array(UTF8String.fromString("r1")), Array(UTF8String.fromString("r2"), UTF8String.fromString("r3")), - Array(UTF8String.fromString("r4")) - ) - val binaryListOfLists = - stringListOfLists.map(stringList => stringList.map(_.getBytes)) + Array(UTF8String.fromString("r4"))) + val binaryListOfLists = stringListOfLists.map(stringList => stringList.map(_.getBytes)) - val rowArray = Array[Any]( - 10, + val rowArray = Array[Any](10, + createArray(UTF8String.fromString("r1"), UTF8String.fromString("r2"), UTF8String.fromString(("r3"))), + createArray( + createArray(3L, 5L), + createArray(-8L, 0L) + ), createArray( - UTF8String.fromString("r1"), - UTF8String.fromString("r2"), - UTF8String.fromString(("r3")) + createArray(1.5F, -6.5F), + createArray(-8.2F, 0F) + ), + createArray( + createArray(3.0), + createArray(6.0, 9.0) ), - createArray(createArray(3L, 5L), createArray(-8L, 0L)), - createArray(createArray(1.5F, -6.5F), createArray(-8.2F, 0F)), - createArray(createArray(3.0), createArray(6.0, 9.0)), createArray( createArray(Decimal(2.0), 
Decimal(4.0)), createArray(Decimal(6.0)) @@ -262,8 +174,7 @@ class TFRecordSerializerTest extends WordSpecLike with Matchers { createArray(UTF8String.fromString("r2"), UTF8String.fromString("r3")), createArray(UTF8String.fromString("r4")) ), - createArray( - createArray("r1".getBytes()), + createArray(createArray("r1".getBytes()), createArray("r2".getBytes(), "r3".getBytes), createArray("r4".getBytes()) ) @@ -279,96 +190,58 @@ class TFRecordSerializerTest extends WordSpecLike with Matchers { val featureListMap = tfexample.getFeatureLists.getFeatureListMap.asScala assert(featureMap.size == 2) - assert( - featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER - ) + assert(featureMap("IntegerLabel").getKindCase.getNumber == Feature.INT64_LIST_FIELD_NUMBER) assert(featureMap("IntegerLabel").getInt64List.getValue(0).toInt == 10) - assert( - featureMap("StringArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER - ) - assert( - featureMap("StringArrayLabel").getBytesList.getValueList.asScala - .map(x => UTF8String.fromString(x.toStringUtf8)) === stringList - ) + assert(featureMap("StringArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) + assert(featureMap("StringArrayLabel").getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)) === stringList) assert(featureListMap.size == 6) - assert( - featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala - .map(_.getInt64List.getValueList.asScala.toSeq) === longListOfLists - ) + assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map( + _.getInt64List.getValueList.asScala.toSeq) === longListOfLists) - assert( - featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map( - _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq - ) ~== floatListOfLists.map { arr => - arr.toSeq - }.toSeq - ) - assert( - featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map( - 
_.getFloatList.getValueList.asScala.map(_.toDouble).toSeq - ) ~== doubleListOfLists.map { arr => - arr.toSeq - }.toSeq - ) + assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map( + _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists.map{arr => arr.toSeq}.toSeq) + assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map( + _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists.map{arr => arr.toSeq}.toSeq) - assert( - featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.map( - _.getFloatList.getValueList.asScala - .map(x => Decimal(x.toDouble)) - .toSeq - ) ~== decimalListOfLists.map { arr => - arr.toSeq - }.toSeq - ) + assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.map( + _.getFloatList.getValueList.asScala.map(x => Decimal(x.toDouble)).toSeq) ~== decimalListOfLists.map{arr => arr.toSeq}.toSeq) - assert( - featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( - _.getBytesList.getValueList.asScala - .map(x => UTF8String.fromString(x.toStringUtf8)) - .toSeq - ) === stringListOfLists - ) + assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( + _.getBytesList.getValueList.asScala.map(x => UTF8String.fromString(x.toStringUtf8)).toSeq) === stringListOfLists) - assert( - featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map( - _.getBytesList.getValueList.asScala - .map(byteList => byteList.asScala.toSeq) - ) === binaryListOfLists.map(_.map(_.toSeq)) - ) + assert(featureListMap("BinaryArrayOfArrayLabel").getFeatureList.asScala.map( + _.getBytesList.getValueList.asScala.map(byteList => byteList.asScala.toSeq)) === binaryListOfLists.map(_.map(_.toSeq))) } "Throw an exception for non-nullable data types" in { - val schemaStructType = StructType( - Array( - StructField("NonNullLabel", ArrayType(FloatType), nullable = false) - ) - ) + val schemaStructType = StructType(Array( + 
StructField("NonNullLabel", ArrayType(FloatType), nullable = false) + )) val internalRow = InternalRow.fromSeq(Array[Any](null)) val serializer = new TFRecordSerializer(schemaStructType) - intercept[NullPointerException] { + intercept[NullPointerException]{ serializer.serializeExample(internalRow) } - intercept[NullPointerException] { + intercept[NullPointerException]{ serializer.serializeSequenceExample(internalRow) } } "Omit null fields from Example for nullable data types" in { - val schemaStructType = StructType( - Array( - StructField("NullLabel", ArrayType(FloatType), nullable = true), - StructField("FloatArrayLabel", ArrayType(FloatType)) - ) - ) + val schemaStructType = StructType(Array( + StructField("NullLabel", ArrayType(FloatType), nullable = true), + StructField("FloatArrayLabel", ArrayType(FloatType)) + )) val floatArray = Array(2.5F, 5.0F) - val internalRow = - InternalRow.fromSeq(Array[Any](null, createArray(2.5F, 5.0F))) + val internalRow = InternalRow.fromSeq( + Array[Any](null, createArray(2.5F, 5.0F)) + ) val serializer = new TFRecordSerializer(schemaStructType) val tfexample = serializer.serializeExample(internalRow) @@ -376,108 +249,74 @@ class TFRecordSerializerTest extends WordSpecLike with Matchers { //Verify each Datatype converted to TensorFlow datatypes val featureMap = tfexample.getFeatures.getFeatureMap.asScala assert(featureMap.size == 1) - assert( - featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) - assert( - featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq - .map(_.toFloat) ~== floatArray.toSeq - ) + assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq) } "Omit null fields from SequenceExample for nullable data types" in { - val schemaStructType = StructType( - Array( - StructField("NullLabel", ArrayType(FloatType), 
nullable = true), - StructField("FloatArrayLabel", ArrayType(FloatType)) - ) - ) + val schemaStructType = StructType(Array( + StructField("NullLabel", ArrayType(FloatType), nullable = true), + StructField("FloatArrayLabel", ArrayType(FloatType)) + )) val floatArray = Array(2.5F, 5.0F) - val internalRow = - InternalRow.fromSeq(Array[Any](null, createArray(2.5F, 5.0F))) + val internalRow = InternalRow.fromSeq( + Array[Any](null, createArray(2.5F, 5.0F))) val serializer = new TFRecordSerializer(schemaStructType) val tfSeqExample = serializer.serializeSequenceExample(internalRow) //Verify each Datatype converted to TensorFlow datatypes val featureMap = tfSeqExample.getContext.getFeatureMap.asScala - val featureListMap = - tfSeqExample.getFeatureLists.getFeatureListMap.asScala + val featureListMap = tfSeqExample.getFeatureLists.getFeatureListMap.asScala assert(featureMap.size == 1) assert(featureListMap.isEmpty) - assert( - featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER - ) - assert( - featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq - .map(_.toFloat) ~== floatArray.toSeq - ) + assert(featureMap("FloatArrayLabel").getKindCase.getNumber == Feature.FLOAT_LIST_FIELD_NUMBER) + assert(featureMap("FloatArrayLabel").getFloatList.getValueList.asScala.toSeq.map(_.toFloat) ~== floatArray.toSeq) } "Throw an exception for unsupported data types" in { - val schemaStructType = - StructType(Array(StructField("TimestampLabel", TimestampType))) + val schemaStructType = StructType(Array( + StructField("TimestampLabel", TimestampType) + )) - intercept[RuntimeException] { + intercept[RuntimeException]{ new TFRecordSerializer(schemaStructType) } } - val schema = StructType(Array(StructField("bytesLabel", BinaryType))) + val schema = StructType(Array( + StructField("bytesLabel", BinaryType)) + ) val serializer = new TFRecordSerializer(schema) "Test Int64ListFeature" in { val longFeature = serializer.Int64ListFeature(Seq(10L)) - val 
longListFeature = serializer.Int64ListFeature(Seq(3L, 5L, 6L)) + val longListFeature = serializer.Int64ListFeature(Seq(3L,5L,6L)) assert(longFeature.getInt64List.getValueList.asScala.toSeq === Seq(10L)) - assert( - longListFeature.getInt64List.getValueList.asScala.toSeq === Seq( - 3L, - 5L, - 6L - ) - ) + assert(longListFeature.getInt64List.getValueList.asScala.toSeq === Seq(3L, 5L, 6L)) } "Test floatListFeature" in { val floatFeature = serializer.floatListFeature(Seq(10.1F)) - val floatListFeature = serializer.floatListFeature(Seq(3.1F, 5.1F, 6.1F)) + val floatListFeature = serializer.floatListFeature(Seq(3.1F,5.1F,6.1F)) - assert( - floatFeature.getFloatList.getValueList.asScala.toSeq === Seq(10.1F) - ) - assert( - floatListFeature.getFloatList.getValueList.asScala.toSeq === Seq( - 3.1F, - 5.1F, - 6.1F - ) - ) + assert(floatFeature.getFloatList.getValueList.asScala.toSeq === Seq(10.1F)) + assert(floatListFeature.getFloatList.getValueList.asScala.toSeq === Seq(3.1F,5.1F,6.1F)) } "Test bytesListFeature" in { - val bytesFeature = - serializer.bytesListFeature(Seq(Array(0xff.toByte, 0xd8.toByte))) - val bytesListFeature = serializer.bytesListFeature( - Seq(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte)) - ) - - assert( - bytesFeature.getBytesList.getValueList.asScala - .map(_.toByteArray.deep) === - Seq(Array(0xff.toByte, 0xd8.toByte).deep) - ) - assert( - bytesListFeature.getBytesList.getValueList.asScala - .map(_.toByteArray.deep) === - Seq( - Array(0xff.toByte, 0xd8.toByte).deep, - Array(0xff.toByte, 0xd9.toByte).deep - ) - ) + val bytesFeature = serializer.bytesListFeature(Seq(Array(0xff.toByte, 0xd8.toByte))) + val bytesListFeature = serializer.bytesListFeature(Seq( + Array(0xff.toByte, 0xd8.toByte), + Array(0xff.toByte, 0xd9.toByte))) + + assert(bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === + Seq(Array(0xff.toByte, 0xd8.toByte).deep)) + 
assert(bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === + Seq(Array(0xff.toByte, 0xd8.toByte).deep, Array(0xff.toByte, 0xd9.toByte).deep)) } } } diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala index d9e2e06..4a5bbea 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TestingUtils.scala @@ -1,26 +1,23 @@ /** - * Copyright 2016 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package com.linkedin.spark.datasources.tfrecord import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{ - GenericRowWithSchema, - SpecializedGetters -} +import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, SpecializedGetters} import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types._ import org.scalatest.Matchers @@ -28,94 +25,98 @@ import org.scalatest.Matchers object TestingUtils extends Matchers { /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. + */ implicit class FloatArrayWithAlmostEquals(val left: Seq[Float]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Float], epsilon: Float = 1E-6F): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Float], epsilon : Float = 1E-6F): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a === (b +- epsilon) } - } else false + } + else false } } /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. + */ implicit class DoubleArrayWithAlmostEquals(val left: Seq[Double]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Double], epsilon: Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. 
+ */ + def ~==(right: Seq[Double], epsilon : Double = 1E-6): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a === (b +- epsilon) } - } else false + } + else false } } /** - * Implicit class for comparing two decimal values using absolute tolerance. - */ + * Implicit class for comparing two decimal values using absolute tolerance. + */ implicit class DecimalArrayWithAlmostEquals(val left: Seq[Decimal]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Decimal], epsilon: Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Decimal], epsilon : Double = 1E-6): Boolean = { if (left.size === right.size) { - (left zip right) forall { - case (a, b) => a.toDouble === (b.toDouble +- epsilon) - } - } else false + (left zip right) forall { case (a, b) => a.toDouble === (b.toDouble +- epsilon) } + } + else false } } /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. + */ implicit class FloatMatrixWithAlmostEquals(val left: Seq[Seq[Float]]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Seq[Float]], epsilon: Float = 1E-6F): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Seq[Float]], epsilon : Float = 1E-6F): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a ~== (b, epsilon) } - } else false + } + else false } } /** - * Implicit class for comparing two double values using absolute tolerance. - */ + * Implicit class for comparing two double values using absolute tolerance. 
+ */ implicit class DoubleMatrixWithAlmostEquals(val left: Seq[Seq[Double]]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Seq[Double]], epsilon: Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Seq[Seq[Double]], epsilon : Double = 1E-6): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a ~== (b, epsilon) } - } else false + } + else false } } /** - * Implicit class for comparing two decimal values using absolute tolerance. - */ + * Implicit class for comparing two decimal values using absolute tolerance. + */ implicit class DecimalMatrixWithAlmostEquals(val left: Seq[Seq[Decimal]]) { /** - * When the difference of two values are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Seq[Seq[Decimal]], epsilon: Double = 1E-6): Boolean = { + * When the difference of two values are within eps, returns true; otherwise, returns false. 
+ */ + def ~==(right: Seq[Seq[Decimal]], epsilon : Double = 1E-6): Boolean = { if (left.size === right.size) { (left zip right) forall { case (a, b) => a ~== (b, epsilon) } - } else false + } + else false } } @@ -124,122 +125,93 @@ object TestingUtils extends Matchers { */ implicit class InternalRowWithAlmostEquals(val left: InternalRow) { - private type valueCompare = - (SpecializedGetters, SpecializedGetters, Int) => Boolean - private def newValueCompare(dataType: DataType, - epsilon: Float = 1E-6F): valueCompare = - dataType match { - case NullType => - (left, right, ordinal) => - left.get(ordinal, null) == right.get(ordinal, null) - - case IntegerType => - (left, right, ordinal) => - left.getInt(ordinal) === right.getInt(ordinal) - - case LongType => - (left, right, ordinal) => - left.getLong(ordinal) === right.getLong(ordinal) - - case FloatType => - (left, right, ordinal) => - left.getFloat(ordinal) === (right.getFloat(ordinal) +- epsilon) - - case DoubleType => - (left, right, ordinal) => - left.getDouble(ordinal) === (right.getDouble(ordinal) +- epsilon) - - case DecimalType() => - (left, right, ordinal) => - left - .getDecimal( - ordinal, - DecimalType.USER_DEFAULT.precision, - DecimalType.USER_DEFAULT.scale - ) - .toDouble === - (right - .getDecimal( - ordinal, - DecimalType.USER_DEFAULT.precision, - DecimalType.USER_DEFAULT.scale - ) - .toDouble - +- epsilon) - - case StringType => - (left, right, ordinal) => - left.getUTF8String(ordinal).getBytes === right - .getUTF8String(ordinal) - .getBytes - - case BinaryType => - (left, right, ordinal) => - left.getBinary(ordinal) === right.getBinary(ordinal) - - case ArrayType(elementType, _) => - (left, right, ordinal) => - if (left.get(ordinal, null) == null || right.get(ordinal, null) == null) { - left.get(ordinal, null) == right.get(ordinal, null) - } else { - val leftArray = left.getArray(ordinal) - val rightArray = right.getArray(ordinal) - if (leftArray.numElements == rightArray.numElements) { - val len = 
leftArray.numElements() - val elementValueCompare = newValueCompare(elementType) - var result = true - var idx: Int = 0 - while (idx < len && result) { - result = elementValueCompare(leftArray, rightArray, idx) - idx += 1 - } - result - } else false - } - case _ => - throw new RuntimeException( - s"Cannot convert field to unsupported data type ${dataType}" - ) - } + private type valueCompare = (SpecializedGetters, SpecializedGetters, Int) => Boolean + private def newValueCompare( + dataType: DataType, + epsilon : Float = 1E-6F): valueCompare = dataType match { + case NullType => (left, right, ordinal) => + left.get(ordinal, null) == right.get(ordinal, null) + + case IntegerType => (left, right, ordinal) => + left.getInt(ordinal) === right.getInt(ordinal) + + case LongType => (left, right, ordinal) => + left.getLong(ordinal) === right.getLong(ordinal) + + case FloatType => (left, right, ordinal) => + left.getFloat(ordinal) === (right.getFloat(ordinal) +- epsilon) + + case DoubleType => (left, right, ordinal) => + left.getDouble(ordinal) === (right.getDouble(ordinal) +- epsilon) + + case DecimalType() => (left, right, ordinal) => + left.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble === + (right.getDecimal(ordinal, DecimalType.USER_DEFAULT.precision, DecimalType.USER_DEFAULT.scale).toDouble + +- epsilon) + + case StringType => (left, right, ordinal) => + left.getUTF8String(ordinal).getBytes === right.getUTF8String(ordinal).getBytes + + case BinaryType => (left, right, ordinal) => + left.getBinary(ordinal) === right.getBinary(ordinal) + + case ArrayType(elementType, _) => (left, right, ordinal) => + if (left.get(ordinal, null) == null || right.get(ordinal, null) == null ){ + left.get(ordinal, null) == right.get(ordinal, null) + } else { + val leftArray = left.getArray(ordinal) + val rightArray = right.getArray(ordinal) + if (leftArray.numElements == rightArray.numElements) { + val len = leftArray.numElements() + val 
elementValueCompare = newValueCompare(elementType) + var result = true + var idx: Int = 0 + while (idx < len && result) { + result = elementValueCompare(leftArray, rightArray, idx) + idx += 1 + } + result + } else false + } + case _ => throw new RuntimeException(s"Cannot convert field to unsupported data type ${dataType}") + } /** * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. */ - def ~==(right: InternalRow, - schema: StructType, - epsilon: Float = 1E-6F): Boolean = { + def ~==(right: InternalRow, schema: StructType, epsilon : Float = 1E-6F): Boolean = { if (schema != null && schema.fields.size == left.numFields && schema.fields.size == right.numFields) { schema.fields.map(_.dataType).zipWithIndex.forall { case (dataType, idx) => - val valueCompare = newValueCompare(dataType) - valueCompare(left, right, idx) + val valueCompare = newValueCompare(dataType) + valueCompare(left, right, idx) } - } else false + } + else false } } /** - * Implicit class for comparing two rows using absolute tolerance. - */ + * Implicit class for comparing two rows using absolute tolerance. + */ implicit class RowWithAlmostEquals(val left: Row) { /** - * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. - */ + * When all fields in row with given schema are equal or are within eps, returns true; otherwise, returns false. 
+ */ def ~==(right: Row, schema: StructType): Boolean = { if (schema != null && schema.fields.size == left.size && schema.fields.size == right.size) { - val leftRowWithSchema = - new GenericRowWithSchema(left.toSeq.toArray, schema) - val rightRowWithSchema = - new GenericRowWithSchema(right.toSeq.toArray, schema) + val leftRowWithSchema = new GenericRowWithSchema(left.toSeq.toArray, schema) + val rightRowWithSchema = new GenericRowWithSchema(right.toSeq.toArray, schema) leftRowWithSchema ~== rightRowWithSchema - } else false + } + else false } /** - * When all fields in row are equal or are within eps, returns true; otherwise, returns false. - */ - def ~==(right: Row, epsilon: Float = 1E-6F): Boolean = { + * When all fields in row are equal or are within eps, returns true; otherwise, returns false. + */ + def ~==(right: Row, epsilon : Float = 1E-6F): Boolean = { if (left.size === right.size) { val leftDataTypes = left.schema.fields.map(_.dataType) val rightDataTypes = right.schema.fields.map(_.dataType) @@ -255,99 +227,54 @@ object TestingUtils extends Matchers { left.getDouble(i) === (right.getDouble(i) +- epsilon) case ((BinaryType, BinaryType), i) => - left.getAs[Array[Byte]](i).toSeq === right - .getAs[Array[Byte]](i) - .toSeq - - case ((ArrayType(FloatType, _), ArrayType(FloatType, _)), i) => - val leftArray = - ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq - val rightArray = - ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq + left.getAs[Array[Byte]](i).toSeq === right.getAs[Array[Byte]](i).toSeq + + case ((ArrayType(FloatType,_), ArrayType(FloatType,_)), i) => + val leftArray = ArrayData.toArrayData(left.get(i)).toFloatArray().toSeq + val rightArray = ArrayData.toArrayData(right.get(i)).toFloatArray().toSeq leftArray ~== (rightArray, epsilon) - case ((ArrayType(DoubleType, _), ArrayType(DoubleType, _)), i) => - val leftArray = - ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq - val rightArray = - 
ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq + case ((ArrayType(DoubleType,_), ArrayType(DoubleType,_)), i) => + val leftArray = ArrayData.toArrayData(left.get(i)).toDoubleArray().toSeq + val rightArray = ArrayData.toArrayData(right.get(i)).toDoubleArray().toSeq leftArray ~== (rightArray, epsilon) - case ((ArrayType(BinaryType, _), ArrayType(BinaryType, _)), i) => - val leftArray = ArrayData - .toArrayData(left.get(i)) - .toArray[Array[Byte]](BinaryType) - .map(_.toSeq) - .toSeq - val rightArray = ArrayData - .toArrayData(right.get(i)) - .toArray[Array[Byte]](BinaryType) - .map(_.toSeq) - .toSeq + case ((ArrayType(BinaryType,_), ArrayType(BinaryType,_)), i) => + val leftArray = ArrayData.toArrayData(left.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq + val rightArray = ArrayData.toArrayData(right.get(i)).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq leftArray === rightArray - case ( - ( - ArrayType(ArrayType(FloatType, _), _), - ArrayType(ArrayType(FloatType, _), _) - ), - i - ) => - val leftArrays = - ArrayData.toArrayData(left.get(i)).array.toSeq.map { arr => - ArrayData.toArrayData(arr).toFloatArray().toSeq - } - val rightArrays = - ArrayData.toArrayData(right.get(i)).array.toSeq.map { arr => - ArrayData.toArrayData(arr).toFloatArray().toSeq - } + case ((ArrayType(ArrayType(FloatType,_),_), ArrayType(ArrayType(FloatType,_),_)), i) => + val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } + val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toFloatArray().toSeq + } leftArrays ~== (rightArrays, epsilon) - case ( - ( - ArrayType(ArrayType(DoubleType, _), _), - ArrayType(ArrayType(DoubleType, _), _) - ), - i - ) => - val leftArrays = - ArrayData.toArrayData(left.get(i)).array.toSeq.map { arr => - ArrayData.toArrayData(arr).toDoubleArray().toSeq - } - val rightArrays = - 
ArrayData.toArrayData(right.get(i)).array.toSeq.map { arr => - ArrayData.toArrayData(arr).toDoubleArray().toSeq - } + case ((ArrayType(ArrayType(DoubleType,_),_), ArrayType(ArrayType(DoubleType,_),_)), i) => + val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toDoubleArray().toSeq + } + val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toDoubleArray().toSeq + } leftArrays ~== (rightArrays, epsilon) - case ( - ( - ArrayType(ArrayType(BinaryType, _), _), - ArrayType(ArrayType(BinaryType, _), _) - ), - i - ) => - val leftArrays = - ArrayData.toArrayData(left.get(i)).array.toSeq.map { arr => - ArrayData - .toArrayData(arr) - .toArray[Array[Byte]](BinaryType) - .map(_.toSeq) - .toSeq - } - val rightArrays = - ArrayData.toArrayData(right.get(i)).array.toSeq.map { arr => - ArrayData - .toArrayData(arr) - .toArray[Array[Byte]](BinaryType) - .map(_.toSeq) - .toSeq - } + case ((ArrayType(ArrayType(BinaryType,_),_), ArrayType(ArrayType(BinaryType,_),_)), i) => + val leftArrays = ArrayData.toArrayData(left.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq + } + val rightArrays = ArrayData.toArrayData(right.get(i)).array.toSeq.map {arr => + ArrayData.toArrayData(arr).toArray[Array[Byte]](BinaryType).map(_.toSeq).toSeq + } leftArrays === rightArrays - case ((a, b), i) => left.get(i) === right.get(i) + case((a,b), i) => left.get(i) === right.get(i) } - } else false + } + else false } } -} +} \ No newline at end of file From 36a5ae77777f1a9d343e2d2da4fca79ba2350f77 Mon Sep 17 00:00:00 2001 From: Greg Roodt Date: Mon, 31 Aug 2020 13:28:36 +1000 Subject: [PATCH 7/7] Update README. 
--- README.md | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index abbd9ff..33406d4 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,9 @@ The implementation is based on [Spark Tensorflow Connector](https://github.com/t ## Including the library The artifacts are published to [bintray](https://bintray.com/linkedin/maven/spark-tfrecord) and [maven central](https://search.maven.org/search?q=spark-tfrecord) repositories. -Current releases were built with scala-2.11. -- Version 0.1.x is based on Spark 2.3. -- Version 0.2.x is based on Spark 3.0. +- Version 0.1.x targets Spark 2.3 and Scala 2.11 +- Version 0.2.x targets Spark 2.4 and both Scala 2.11 and 2.12 To use the package, please include the dependency as follows @@ -28,14 +27,13 @@ The library can be built with Maven 3.3.9 or newer as shown below: # Build Spark-TFRecord git clone https://github.com/linkedin/spark-tfrecord.git cd spark-tfrecord -mvn clean install +mvn -Pscala-2.11 clean install # One can specify the spark version and tensorflow hadoop version, for example -mvn clean install -Dspark.version=2.4.6 -Dtensorflow.hadoop.version=1.15.0 +mvn -Pscala-2.11 clean install -Dspark.version=2.4.6 -Dtensorflow.hadoop.version=1.15.0 # Or for building with Spark 3, use the following -mvn clean install -Dspark.version=3.0.0 -Dscala.binary.version=2.12 -Dscala.compiler.version=2.12.11 -Dscala.test.version=3.0.0 -# In this instance we would suggest changing the `<artifactId>` in the `pom.xml` to spark-tfrecord_2.12 +mvn -Pscala-2.12 clean install -Dspark.version=3.0.0 ``` ## Using Spark Shell