Filter and Where Conditions in Spark DataFrame - Scala

Learn how to use filter and where conditions when working with Spark DataFrames in Scala. This tutorial walks you through applying conditional logic to your data, allowing you to retrieve specific subsets of rows based on given criteria.

The examples below use the following sample data:

Sample Data

Roll   First Name   Age   Last Name
1      rahul        30    yadav
2      sanjay       20    gupta
3      ranjan       67    kumar
4      ravi         67    kumar

Step 1: Import Required Libraries

First, you need to import the necessary libraries:

import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

Step 2: Create Sample DataFrame

For demonstration purposes, let's create a SparkSession and a sample DataFrame:

// A SparkSession is required to build the DataFrame
val sparkSession = SparkSession
  .builder()
  .appName("Filter and where conditions in spark dataframe")
  .master("local")
  .getOrCreate()

// Schema for the sample data
val schema = StructType(Array(
  StructField("roll", IntegerType, true),
  StructField("first_name", StringType, true),
  StructField("age", IntegerType, true),
  StructField("last_name", StringType, true)
))

// Sample rows matching the schema
val data = Seq(
  Row(1, "rahul", 30, "yadav"),
  Row(2, "sanjay", 20, "gupta"),
  Row(3, "ranjan", 67, "kumar"),
  Row(4, "ravi", 67, "kumar")
)

// Create the DataFrame from an RDD of rows plus the schema
val rdd = sparkSession.sparkContext.parallelize(data)
val testDF = sparkSession.createDataFrame(rdd, schema)
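
To confirm that the DataFrame was built correctly, you can print it with show(), which displays all four sample rows:

testDF.show()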

Step 3: Apply Filter and Where Conditions

Now, let's apply filter and where conditions to the DataFrame:

Using Filter:

val transformedDF = testDF.filter(col("age") > lit(30))

Using Where:

val transformedDF = testDF.where(col("age") > lit(30))
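
Note that where is simply an alias for filter, so both snippets return the same result. The condition can also be written as a SQL expression string, and conditions on multiple columns can be combined with && and ||. A minimal sketch (the variable names here are illustrative):

// Equivalent filter written as a SQL expression string
val filteredByExpr = testDF.filter("age > 30")

// Combining two column conditions with && (logical AND)
val kumarsOver30 = testDF.filter(col("age") > lit(30) && col("last_name") === lit("kumar"))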

Complete Code

  import org.apache.spark.sql.functions.{col, lit}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
  import org.apache.spark.sql.{Row, SparkSession}

  object ApplyFilter {

    def main(args: Array[String]): Unit = {
      // Create a local SparkSession
      val sparkSession = SparkSession
        .builder()
        .appName("Filter and where conditions in spark dataframe")
        .master("local")
        .getOrCreate()

      // Schema for the sample data
      val schema = StructType(Array(
        StructField("roll", IntegerType, true),
        StructField("first_name", StringType, true),
        StructField("age", IntegerType, true),
        StructField("last_name", StringType, true)
      ))

      // Sample rows matching the schema
      val data = Seq(
        Row(1, "rahul", 30, "yadav"),
        Row(2, "sanjay", 20, "gupta"),
        Row(3, "ranjan", 67, "kumar"),
        Row(4, "ravi", 67, "kumar")
      )

      // Build the DataFrame and keep only rows with age greater than 30
      val rdd = sparkSession.sparkContext.parallelize(data)
      val testDF = sparkSession.createDataFrame(rdd, schema)
      val transformedDF = testDF.filter(col("age") > lit(30))
      transformedDF.show()

      sparkSession.stop()
    }
  }
  

That's it! You've successfully applied filter and where conditions to a DataFrame in Spark using Scala.

Output

Running the program prints the rows whose age is greater than 30:

+----+----------+---+---------+
|roll|first_name|age|last_name|
+----+----------+---+---------+
|   3|    ranjan| 67|    kumar|
|   4|      ravi| 67|    kumar|
+----+----------+---+---------+