How to Use collect in Spark Scala - Retrieve DataFrame Rows

Using collect in Spark Scala

The collect function in Apache Spark retrieves all rows of a DataFrame to the driver node as an Array[Row]. It is useful when the result is small enough to process in the driver's local memory; calling it on a very large DataFrame can exhaust driver memory. In this tutorial, we will cover how to use the collect function in Spark using Scala with practical examples.
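
Because collect copies every row to the driver, it can help to bound the result when only a few rows are needed. The following is a minimal sketch, assuming a local SparkSession and a hypothetical DataFrame of ids built with spark.range; the object and method names are illustrative only. It uses take(n), which returns at most n rows as an Array[Row].

```scala
import org.apache.spark.sql.{Row, SparkSession}

object BoundedCollect {
  // Builds a local session; "local[1]" is an assumption made for this sketch.
  def localSession(): SparkSession =
    SparkSession.builder()
      .appName("bounded collect sketch")
      .master("local[1]")
      .getOrCreate()

  // Returns at most `n` rows to the driver instead of the whole DataFrame.
  def firstRows(spark: SparkSession, total: Long, n: Int): Array[Row] = {
    val df = spark.range(total).toDF("id") // hypothetical sample data
    df.take(n)
  }

  def main(args: Array[String]): Unit = {
    val spark = localSession()
    firstRows(spark, 1000L, 5).foreach(println)
    spark.stop()
  }
}
```

If every row really is needed on the driver, collect remains the right call; take is only a way to inspect a sample without pulling the full result.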

1. Sample Data

We will start by creating a sample DataFrame to demonstrate the collect operation. Here's the data we'll use:

Roll  First Name  Age  Last Name
1     rahul       30   yadav
2     sanjay      20   gupta
3     ranjan      67   kumar

2. Importing Necessary Libraries

We need to import the necessary Spark libraries:

import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}    

3. Creating a DataFrame

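// Create (or reuse) a local SparkSession
val sparkSession = SparkSession
  .builder()
  .appName("Collect in spark scala")
  .master("local")
  .getOrCreate()

// Schema: column name, data type, nullable
val schema = StructType(Array(
  StructField("roll", IntegerType, true),
  StructField("first_name", StringType, true),
  StructField("age", IntegerType, true),
  StructField("last_name", StringType, true)
))
val data = Seq(
  Row(1, "rahul", 30, "yadav"),
  Row(2, "sanjay", 20, "gupta"),
  Row(3, "ranjan", 67, "kumar")
)
// Turn the rows into an RDD, then apply the schema to get a DataFrame
val rdd = sparkSession.sparkContext.parallelize(data)
val testDF = sparkSession.createDataFrame(rdd, schema)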

4. Using the collect Function

Calling collect on a DataFrame returns its rows to the driver as an Array[Row], which you can iterate over like any other Scala array:

val collectedRows: Array[Row] = testDF.collect()
collectedRows.foreach(row => println(row))
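
Printing a Row shows all of its fields at once. To work with individual fields, one option is Row.getAs[T], which looks a field up by column name. The sketch below is illustrative only: the Student case class and the ReadCollectedFields object are hypothetical names, and the sample rows are rebuilt with toDF for brevity rather than with an explicit schema.

```scala
import org.apache.spark.sql.SparkSession

object ReadCollectedFields {
  // Hypothetical case class, introduced only for this illustration.
  final case class Student(roll: Int, firstName: String, age: Int, lastName: String)

  def collectStudents(spark: SparkSession): Array[Student] = {
    import spark.implicits._ // enables toDF on a local Seq

    // Same sample rows as the tutorial, built with toDF for brevity.
    val df = Seq(
      (1, "rahul", 30, "yadav"),
      (2, "sanjay", 20, "gupta"),
      (3, "ranjan", 67, "kumar")
    ).toDF("roll", "first_name", "age", "last_name")

    // getAs[T](name) reads a field by column name instead of by position.
    df.collect().map { row =>
      Student(
        row.getAs[Int]("roll"),
        row.getAs[String]("first_name"),
        row.getAs[Int]("age"),
        row.getAs[String]("last_name")
      )
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("getAs sketch")
      .master("local[1]") // assumption: local run for the sketch
      .getOrCreate()
    collectStudents(spark).foreach(println)
    spark.stop()
  }
}
```

Reading fields by name keeps the code working if the column order changes, at the cost of a runtime error if a name is misspelled.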

5. Storing a Column Value in a String Variable

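To read a single value, filter the DataFrame down to the row you want, collect the result, and then index into the returned array: the first index selects the row and the second selects the column by position.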

// Get the first name of the student whose roll number is 2
val firstName: String = testDF.filter(col("roll") === lit(2)).collect()(0)(1).toString
println(firstName)
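
Indexing with collect()(0) throws an exception when no row matches the filter. A more defensive variant, shown here as a sketch under the same sample data, selects only the needed column, fetches at most one row, and returns an Option. The SafeLookup object and its method names are hypothetical.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object SafeLookup {
  // Rebuilds the tutorial's sample rows; toDF is used here for brevity.
  def sampleData(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(
      (1, "rahul", 30, "yadav"),
      (2, "sanjay", 20, "gupta"),
      (3, "ranjan", 67, "kumar")
    ).toDF("roll", "first_name", "age", "last_name")
  }

  // Returns None when no row matches or when the value is null.
  def firstNameForRoll(df: DataFrame, roll: Int): Option[String] =
    df.filter(col("roll") === roll)
      .select("first_name") // bring back only the column we need
      .take(1)              // at most one row reaches the driver
      .headOption
      .flatMap(row => Option(row.getString(0)))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("safe lookup sketch")
      .master("local[1]") // assumption: local run for the sketch
      .getOrCreate()
    val df = sampleData(spark)
    println(firstNameForRoll(df, 2))  // Some(sanjay)
    println(firstNameForRoll(df, 99)) // None
    spark.stop()
  }
}
```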

Complete Code

import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CollectInSpark {

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .appName("Collect in spark scala")
      .master("local")
      .getOrCreate()

    val schema = StructType(Array(
      StructField("roll", IntegerType, true),
      StructField("first_name", StringType, true),
      StructField("age", IntegerType, true),
      StructField("last_name", StringType, true)
    ))
    val data = Seq(
      Row(1, "rahul", 30, "yadav"),
      Row(2, "sanjay", 20, "gupta"),
      Row(3, "ranjan", 67, "kumar")
    )
    val rdd = sparkSession.sparkContext.parallelize(data)
    val testDF = sparkSession.createDataFrame(rdd, schema)

    // Retrieve every row to the driver and print it
    val collectedRows: Array[Row] = testDF.collect()
    collectedRows.foreach(row => println(row))

    // Get the first name of the student whose roll number is 2
    val firstName: String = testDF.filter(col("roll") === lit(2)).collect()(0)(1).toString
    println(firstName)
  }
}

Output

[1,rahul,30,yadav]
[2,sanjay,20,gupta]
[3,ranjan,67,kumar]
sanjay