|
2 | 2 |
|
3 | 3 | import org.apache.log4j.Level; |
4 | 4 | import org.apache.log4j.Logger; |
5 | | -import org.apache.spark.SparkConf; |
6 | | -import org.apache.spark.api.java.JavaRDD; |
7 | | -import org.apache.spark.api.java.JavaSparkContext; |
8 | | -import org.apache.spark.sql.Dataset; |
9 | | -import org.apache.spark.sql.Encoders; |
10 | | -import org.apache.spark.sql.SparkSession; |
| 5 | +import org.apache.spark.sql.*; |
11 | 6 |
|
12 | 7 | import static org.apache.spark.sql.functions.avg; |
| 8 | +import static org.apache.spark.sql.functions.col; |
13 | 9 | import static org.apache.spark.sql.functions.max; |
14 | 10 |
|
15 | 11 |
|
16 | 12 | public class TypedDataset { |
17 | 13 | private static final String AGE_MIDPOINT = "ageMidpoint"; |
18 | 14 | private static final String SALARY_MIDPOINT = "salaryMidPoint"; |
19 | 15 | private static final String SALARY_MIDPOINT_BUCKET = "salaryMidpointBucket"; |
20 | | - private static final float NULL_VALUE = -1.0f; |
21 | | - private static final String COMMA_DELIMITER = ",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"; |
22 | 16 |
|
23 | 17 | public static void main(String[] args) throws Exception { |
24 | 18 |
|
25 | 19 | Logger.getLogger("org").setLevel(Level.ERROR); |
26 | | - SparkConf conf = new SparkConf().setAppName("StackOverFlowSurvey").setMaster("local[1]"); |
| 20 | + SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
27 | 21 |
|
28 | | - JavaSparkContext sc = new JavaSparkContext(conf); |
| 22 | + DataFrameReader dataFrameReader = session.read(); |
29 | 23 |
|
30 | | - SparkSession session = SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
| 24 | + Dataset<Row> responses = dataFrameReader.option("header","true").csv("in/2016-stack-overflow-survey-responses.csv"); |
31 | 25 |
|
32 | | - JavaRDD<String> lines = sc.textFile("in/2016-stack-overflow-survey-responses.csv"); |
| 26 | + Dataset<Row> responseWithSelectedColumns = responses.select(col("country"), col("age_midpoint").as("ageMidPoint").cast("integer"), col("occupation"), col("salary_midpoint").as("salaryMidPoint").cast("integer")); |
33 | 27 |
|
34 | | - JavaRDD<Response> responseRDD = lines |
35 | | - .filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country")) |
36 | | - .map(line -> { |
37 | | - String[] splits = line.split(COMMA_DELIMITER, -1); |
38 | | - return new Response(splits[2], convertStringToFloat(splits[6]), splits[9], convertStringToFloat(splits[14])); |
39 | | - }); |
40 | | - Dataset<Response> responseDataset = session.createDataset(responseRDD.rdd(), Encoders.bean(Response.class)); |
| 28 | + Dataset<Response> typedDataset = responseWithSelectedColumns.as(Encoders.bean(Response.class)); |
41 | 29 |
|
42 | 30 | System.out.println("=== Print out schema ==="); |
43 | | - responseDataset.printSchema(); |
| 31 | + typedDataset.printSchema(); |
44 | 32 |
|
45 | 33 | System.out.println("=== Print 20 records of responses table ==="); |
46 | | - responseDataset.show(20); |
| 34 | + typedDataset.show(20); |
47 | 35 |
|
48 | 36 | System.out.println("=== Print records where the response is from Afghanistan ==="); |
49 | | - responseDataset.filter(response -> response.getCountry().equals("Afghanistan")).show(); |
| 37 | + typedDataset.filter(response -> response.getCountry().equals("Afghanistan")).show(); |
50 | 38 |
|
51 | 39 | System.out.println("=== Print the count of occupations ==="); |
52 | | - responseDataset.groupBy(responseDataset.col("occupation")).count().show(); |
53 | | - |
| 40 | + typedDataset.groupBy(typedDataset.col("occupation")).count().show(); |
54 | 41 |
|
55 | 42 | System.out.println("=== Print records with average mid age less than 20 ==="); |
56 | | - responseDataset.filter(response -> response.getAgeMidPoint() != NULL_VALUE && response.getAgeMidPoint() < 20).show(); |
| 43 | + typedDataset.filter(response -> response.getAgeMidPoint() !=null && response.getAgeMidPoint() < 20).show(); |
57 | 44 |
|
58 | 45 | System.out.println("=== Print the result with salary middle point in descending order ==="); |
59 | | - responseDataset.orderBy(responseDataset.col(SALARY_MIDPOINT ).desc()).show(); |
| 46 | + typedDataset.orderBy(typedDataset.col(SALARY_MIDPOINT ).desc()).show(); |
60 | 47 |
|
61 | 48 | System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ==="); |
62 | | - responseDataset |
63 | | - .filter(response -> response.getSalaryMidPoint() != NULL_VALUE) |
64 | | - .groupBy("country") |
65 | | - .agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT)) |
66 | | - .show(); |
| 49 | + typedDataset.filter(response -> response.getSalaryMidPoint() != null) |
| 50 | + .groupBy("country") |
| 51 | + .agg(avg(SALARY_MIDPOINT), max(AGE_MIDPOINT)) |
| 52 | + .show(); |
67 | 53 |
|
68 | 54 | System.out.println("=== Group by salary bucket ==="); |
69 | | - |
70 | | - responseDataset |
71 | | - .map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT()) |
72 | | - .withColumnRenamed("value", SALARY_MIDPOINT_BUCKET) |
73 | | - .groupBy(SALARY_MIDPOINT_BUCKET) |
74 | | - .count() |
75 | | - .orderBy(SALARY_MIDPOINT_BUCKET).show(); |
| 55 | + typedDataset.filter(response -> response.getSalaryMidPoint() != null) |
| 56 | + .map(response -> Math.round(response.getSalaryMidPoint()/20000) * 20000, Encoders.INT()) |
| 57 | + .withColumnRenamed("value", SALARY_MIDPOINT_BUCKET) |
| 58 | + .groupBy(SALARY_MIDPOINT_BUCKET) |
| 59 | + .count() |
| 60 | + .orderBy(SALARY_MIDPOINT_BUCKET).show(); |
76 | 61 | } |
77 | | - |
78 | | - private static float convertStringToFloat(String split) { |
79 | | - return split.isEmpty() ? NULL_VALUE : Float.valueOf(split); |
80 | | - } |
81 | | - |
82 | 62 | } |
0 commit comments