How to Use groupBy in Spark Scala - Grouping and Aggregating Data

Grouping and aggregating data is a fundamental part of data analysis. In Apache Spark, the groupBy function groups DataFrame rows by one or more columns, and agg computes aggregates over each group. This tutorial walks you through using these functions in Scala with practical examples and explanations.

The examples in this tutorial use the sample data below.

Sample Data


Roll  First Name  Age  Last Name  Subject    Marks
1     Rahul       18   Yadav      PHYSICS    80
1     Rahul       18   Yadav      CHEMISTRY  77
1     Rahul       18   Yadav      BIOLOGY    70
2     Vinay       17   kumar      PHYSICS    80
2     Vinay       17   kumar      CHEMISTRY  77
2     Vinay       17   kumar      BIOLOGY    66

Step 1: Import Required Libraries

First, you need to import the necessary libraries:

import org.apache.spark.sql.{Row, SparkSession, functions}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
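
The snippets in the next steps also assume a SparkSession named sparkSession. The complete program at the end of this tutorial creates it like this:

  val sparkSession = SparkSession
    .builder()
    .appName("group by in scala spark dataframe")
    .master("local")
    .getOrCreate()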

Step 2: Create Sample DataFrame

For demonstration purposes, let's create a sample DataFrame:

  val schema = StructType(Array(
    StructField("roll", IntegerType, true),
    StructField("first_name", StringType, true),
    StructField("age", IntegerType, true),
    StructField("last_name", StringType, true),
    StructField("subject", StringType, true),
    StructField("Marks", IntegerType, true)
  ))
  val data = Seq(
    Row(1, "Rahul", 18, "Yadav", "PHYSICS", 80),
    Row(1, "Rahul", 18, "Yadav", "CHEMISTRY", 77),
    Row(1, "Rahul", 18, "Yadav", "BIOLOGY", 70),
    Row(2, "Vinay", 17, "kumar", "PHYSICS", 80),
    Row(2, "Vinay", 17, "kumar", "CHEMISTRY", 77),
    Row(2, "Vinay", 17, "kumar", "BIOLOGY", 66)
  )
  val rdd = sparkSession.sparkContext.parallelize(data)
  val testDF = sparkSession.createDataFrame(rdd, schema)
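
If you call testDF.show() at this point, you should see output similar to:

  +----+----------+---+---------+---------+-----+
  |roll|first_name|age|last_name|  subject|Marks|
  +----+----------+---+---------+---------+-----+
  |   1|     Rahul| 18|    Yadav|  PHYSICS|   80|
  |   1|     Rahul| 18|    Yadav|CHEMISTRY|   77|
  |   1|     Rahul| 18|    Yadav|  BIOLOGY|   70|
  |   2|     Vinay| 17|    kumar|  PHYSICS|   80|
  |   2|     Vinay| 17|    kumar|CHEMISTRY|   77|
  |   2|     Vinay| 17|    kumar|  BIOLOGY|   66|
  +----+----------+---+---------+---------+-----+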

Step 3: Use groupBy on a Single Column

To compute each student's total marks, group by roll and sum the Marks column:

  val groupedDF = testDF.groupBy("roll")
    .agg(functions.sum("Marks").as("total_marks"))
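
The agg method also accepts several aggregate expressions at once, so you can compute multiple statistics per group in a single pass. Here is a minimal sketch; the statsDF, avg_marks, and max_marks names are my own and not part of the original example:

  // Total, average, and highest marks per roll number in one agg call
  val statsDF = testDF.groupBy("roll")
    .agg(
      functions.sum("Marks").as("total_marks"),
      functions.avg("Marks").as("avg_marks"),
      functions.max("Marks").as("max_marks")
    )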

Step 4: Use groupBy with Multiple Columns

You can pass several column names to groupBy. Here the grouping key is effectively still roll, but adding first_name and last_name keeps those columns in the result:

  val groupedDF = testDF.groupBy("roll", "first_name", "last_name")
    .agg(functions.sum("Marks").as("total_marks"))
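
Calling groupedDF.show() on this result should print something similar to:

  +----+----------+---------+-----------+
  |roll|first_name|last_name|total_marks|
  +----+----------+---------+-----------+
  |   1|     Rahul|    Yadav|        227|
  |   2|     Vinay|    kumar|        223|
  +----+----------+---------+-----------+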
    
  

Complete Code

  import org.apache.spark.sql.{Row, SparkSession, functions}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
  
  object GroupByInScalaSpark {
  
    def main(args: Array[String]): Unit = {
      val sparkSession = SparkSession
        .builder()
        .appName("group by in scala spark dataframe")
        .master("local")
        .getOrCreate()
      val schema = StructType(Array(
        StructField("roll", IntegerType, true),
        StructField("first_name", StringType, true),
        StructField("age", IntegerType, true),
        StructField("last_name", StringType, true),
        StructField("subject", StringType, true),
        StructField("Marks", IntegerType, true)
      ))
      val data = Seq(
        Row(1, "Rahul", 18, "Yadav", "PHYSICS", 80),
        Row(1, "Rahul", 18, "Yadav", "CHEMISTRY", 77),
        Row(1, "Rahul", 18, "Yadav", "BIOLOGY", 70),
        Row(2, "Vinay", 17, "kumar", "PHYSICS", 80),
        Row(2, "Vinay", 17, "kumar", "CHEMISTRY", 77),
        Row(2, "Vinay", 17, "kumar", "BIOLOGY", 66)
      )
      val rdd = sparkSession.sparkContext.parallelize(data)
      val testDF = sparkSession.createDataFrame(rdd, schema)
      val groupedDF = testDF.groupBy("roll")
        .agg(functions.sum("Marks").as("total_marks"))
  
      groupedDF.show()
      sparkSession.stop()
  
    }
  
  }     

That's it! You've successfully used groupBy and agg to group and aggregate a DataFrame in Spark using Scala.

Output

+----+-----------+
|roll|total_marks|
+----+-----------+
|   1|        227|
|   2|        223|
+----+-----------+