How to Use collect in Spark Scala - Retrieve DataFrame Rows
The `collect` function in Apache Spark retrieves all rows of a DataFrame to the driver as an array of `Row` objects. This is useful when you need the data in the driver's local memory for further processing, but keep in mind that collecting a large DataFrame can exhaust driver memory, so it is best reserved for small result sets. In this tutorial, we will cover how to use the `collect` function in Spark with Scala through practical examples, using the following sample student data:
| Roll | First Name | Age | Last Name |
|------|------------|-----|-----------|
| 1    | Rahul      | 30  | Yadav     |
| 2    | Sanjay     | 20  | Gupta     |
| 3    | Ranjan     | 67  | Kumar     |
We need to import the necessary Spark libraries:
```scala
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
```
Next, define the schema, create the sample data, and build the DataFrame (these snippets assume a `SparkSession` named `sparkSession`, created as in the complete program at the end of this tutorial):

```scala
val schema = StructType(Array(
  StructField("roll", IntegerType, true),
  StructField("first_name", StringType, true),
  StructField("age", IntegerType, true),
  StructField("last_name", StringType, true)
))

val data = Seq(
  Row(1, "rahul", 30, "yadav"),
  Row(2, "sanjay", 20, "gupta"),
  Row(3, "ranjan", 67, "kumar")
)

val rdd = sparkSession.sparkContext.parallelize(data)
val testDF = sparkSession.createDataFrame(rdd, schema)
```
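Before collecting, you may want to verify that `testDF` was built correctly. A quick optional check (not part of the original walkthrough) is to print it with `show`:

```scala
// Optional sanity check: show is also an action, but it only renders
// a capped number of rows (20 by default) instead of pulling everything.
testDF.show()
```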
The `collect` function returns every row of the DataFrame as an `Array[Row]` on the driver. Here is how you can use it:
```scala
val collectedRows: Array[Row] = testDF.collect()
collectedRows.foreach(row => println(row))
```
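Each element prints in bracketed `Row` form, e.g. `[1,rahul,30,yadav]`. Because `collect` materializes the entire DataFrame on the driver, it can run out of memory on large data. One bounded alternative is `take`, sketched below against the same `testDF`:

```scala
// take fetches only the first n rows to the driver, which is safer
// than collect when the DataFrame may be large.
val firstTwoRows: Array[Row] = testDF.take(2)
firstTwoRows.foreach(println)
```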
You can also combine `filter` with `collect` to pull a single value out of the DataFrame. For example, to get the first name of the student whose roll number is 2:
```scala
// Get the first name of the student whose roll number is 2.
// collect()(0) is the first matching Row; index 1 is the first_name column.
val firstName: String = testDF.filter(col("roll") === lit(2)).collect()(0)(1).toString
println(firstName)
```
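Note that `collect()(0)` throws an exception if the filter matches no rows, and the positional index `1` silently breaks if the column order changes. A more defensive variant (my own sketch, using `headOption` and a lookup by column name) might look like this:

```scala
// headOption avoids an exception when no row matches the filter,
// and getAs[String] selects the column by name instead of position.
val firstNameOpt: Option[String] = testDF
  .filter(col("roll") === lit(2))
  .collect()
  .headOption
  .map(_.getAs[String]("first_name"))

println(firstNameOpt.getOrElse("not found"))
```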
Here is the complete program:

```scala
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object CollectInSpark {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession
      .builder()
      .appName("Collect in spark scala")
      .master("local")
      .getOrCreate()

    val schema = StructType(Array(
      StructField("roll", IntegerType, true),
      StructField("first_name", StringType, true),
      StructField("age", IntegerType, true),
      StructField("last_name", StringType, true)
    ))

    val data = Seq(
      Row(1, "rahul", 30, "yadav"),
      Row(2, "sanjay", 20, "gupta"),
      Row(3, "ranjan", 67, "kumar")
    )

    val rdd = sparkSession.sparkContext.parallelize(data)
    val testDF = sparkSession.createDataFrame(rdd, schema)

    // Retrieve all rows to the driver and print them.
    val collectedRows: Array[Row] = testDF.collect()
    collectedRows.foreach(row => println(row))

    // Get the first name of the student whose roll number is 2.
    val firstName: String = testDF.filter(col("roll") === lit(2)).collect()(0)(1).toString
    println(firstName)

    sparkSession.stop()
  }
}
```
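Finally, if you need the rows as a `java.util.List` instead of a Scala `Array` (for example, to pass to Java code), `Dataset` also provides `collectAsList`. A minimal sketch against the same `testDF`:

```scala
// collectAsList behaves like collect but returns a java.util.List[Row],
// which is convenient for Java interop. It has the same driver-memory caveat.
val rowsAsList: java.util.List[Row] = testDF.collectAsList()
println(rowsAsList.size())
```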