Save column value into string variable - PySpark

Store a column value in a string variable in PySpark using collect()

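The collect() method of a PySpark DataFrame returns all of its rows to the driver as a Python list of Row objects. This is useful when you need the data in local memory for further processing, such as storing a single column value in a Python string variable. Every row is transferred to the driver, so collect() should only be used on small DataFrames or on results that have already been filtered down. In this tutorial, we will cover how to use collect() in PySpark with practical examples.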

1. Sample Data

We will start by creating a sample DataFrame to demonstrate the collect operation. Here's the data we'll use:

Roll  First Name  Age  Last Name
1     rahul       30   yadav
2     sanjay      20   gupta
3     ranjan      67   kumar

2. Importing Necessary Libraries

We need to import the necessary PySpark libraries:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql import Row

3. Creating a DataFrame

# Create (or reuse) a local SparkSession
spark = SparkSession.builder \
    .appName("Collect in PySpark") \
    .master("local") \
    .getOrCreate()

# Define the schema
schema = "roll INT, first_name STRING, age INT, last_name STRING"

# Create the data
data = [
    Row(1, "rahul", 30, "yadav"),
    Row(2, "sanjay", 20, "gupta"),
    Row(3, "ranjan", 67, "kumar"),
]

# Create the DataFrame
test_df = spark.createDataFrame(data, schema=schema)

4. Store a column value in a string variable

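First filter the DataFrame down to the row you need, then call collect(). Because collect() returns a list of Row objects, [0] selects the first matching Row and [1] selects its second field, first_name (fields are indexed in schema order, starting at 0). Note that collect()[0] raises an IndexError if no row matches the filter.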

# Get the first name of the student whose roll number is 2
first_name = test_df.filter(col("roll") == lit(2)).collect()[0][1]
print(first_name)  # sanjay

Complete Code

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from pyspark.sql import Row

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("Collect in PySpark") \
    .master("local") \
    .getOrCreate()

# Define the schema
schema = "roll INT, first_name STRING, age INT, last_name STRING"

# Create the data
data = [
    Row(1, "rahul", 30, "yadav"),
    Row(2, "sanjay", 20, "gupta"),
    Row(3, "ranjan", 67, "kumar"),
]

# Create the DataFrame
test_df = spark.createDataFrame(data, schema=schema)

# Collect all rows from the DataFrame
collected_rows = test_df.collect()
for row in collected_rows:
    print(row)

# Get the name of the student whose roll number is 2
first_name = test_df.filter(col("roll") == lit(2)).collect()[0][1]
print(first_name)

# Stop the SparkSession
spark.stop()


Output:

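Row(roll=1, first_name='rahul', age=30, last_name='yadav')
Row(roll=2, first_name='sanjay', age=20, last_name='gupta')
Row(roll=3, first_name='ranjan', age=67, last_name='kumar')
sanjay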