PySpark Beginner’s Guide

by Alex

PySpark is the Python API for Apache Spark, an open-source system for distributed big data processing. Spark was originally developed in Scala at the University of California, Berkeley, and provides APIs for Scala, Java, Python and R. It supports code reuse across multiple workloads: batch data processing, interactive queries, real-time analytics, machine learning, and graph processing. It uses in-memory caching and optimized query execution to handle data of any size. Spark does not have its own storage system like the Hadoop Distributed File System (HDFS); instead, it works with many popular storage systems such as HDFS, HBase, Cassandra, Amazon S3, Amazon Redshift, Couchbase, etc.

Benefits of using Apache Spark:

  • It can run programs up to 100 times faster than Hadoop MapReduce in memory and up to 10 times faster on disk, because Spark processes data in the main memory of worker nodes and avoids unnecessary I/O operations.
  • Spark is user-friendly: it offers APIs in popular languages and hides the complexity of distributed processing behind simple high-level operators, which greatly reduces the amount of code needed.
  • Spark can be deployed with Apache Mesos, Hadoop YARN, or Spark's own standalone cluster manager.
  • Spark can process data in near real time with low latency thanks to in-memory execution.

Let’s get started.

Setting up the environment in Google Colab

Running PySpark on a local machine requires Java and some other software. Instead of going through a complicated installation procedure, we use Google Colaboratory, which provides enough hardware resources for our purposes and comes with a wide range of libraries for data analysis and machine learning. All we have to do is install the pyspark and Py4J packages. Py4J allows Python programs running in the Python interpreter to dynamically access Java objects in the Java virtual machine. The final notebook can be downloaded from the repository:

The command to install the packages above:
!pip install pyspark==3.0.1 py4j==0.10.9

Spark Session

SparkSession has been the entry point to PySpark since version 2.0; previously, SparkContext played this role. SparkSession initializes the core PySpark functionality and lets you programmatically create RDDs and DataFrames (the typed Dataset API exists only in Scala and Java). It replaces SQLContext, HiveContext, and the other contexts used before 2.0. Internally, SparkSession creates a SparkConf and a SparkContext from the configuration passed to it. You create a SparkSession with SparkSession.builder, which implements the Builder design pattern.

Creating a SparkSession

To create a SparkSession, use SparkSession.builder and chain the methods below. In PySpark, builder is a class attribute, so it is written without parentheses.

  • getOrCreate() returns an existing SparkSession; if it doesn’t exist, a new SparkSession is created.
  • master(): when working with a cluster, pass the master URL of your cluster manager, which is normally yarn or a mesos:// URL depending on your setup. When running locally, use local[x], where x is an integer greater than 0. It sets the number of worker threads and, with it, the default parallelism (for example, the default number of partitions for RDDs created with parallelize()). Ideally, x should match the number of CPU cores; local[*] uses all of them.
  • appName() is used to set the name of your application.

An example of creating a SparkSession:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
# where "*" stands for all processor cores.

Read data

Using spark.read, we can read data from various file formats such as CSV, JSON, Parquet, and others. Here are some examples of loading data from files:
# Read a CSV file
csv_file = 'data/stocks_price_final.csv'
df = spark.read.csv(csv_file, header=True)
# Read a JSON file
json_file = 'data/stocks_price_final.json'
data = spark.read.json(json_file)
# Read a Parquet file
parquet_file = 'data/stocks_price_final.parquet'
data1 = spark.read.parquet(parquet_file)

Structuring data with Spark schema

Let's read the U.S. stock price data from January 2019 to July 2020, which is available as a Kaggle dataset. The data is stored in a CSV file:
data = spark.read.csv('data/stocks_price_final.csv', sep=',', header=True)
Now let's look at the data schema using the printSchema() method:
data.printSchema()
The Spark schema describes the structure of a DataFrame or Dataset. We can define it using the StructType class, which is a collection of StructField objects. Each StructField sets a column's name (String), its type (DataType), whether it accepts NULL (Boolean), and optional metadata. Spark can derive the schema from the data automatically, but defining it ourselves is often useful: the inferred type may be wrong, or we may want our own column names and data types. This often happens when working with completely or partially unstructured data. Let's see how we can structure our data:
from pyspark.sql.types import *
data_schema = [
    StructField('_c0', IntegerType(), True),
    StructField('symbol', StringType(), True),
    StructField('data', DateType(), True),
    StructField('open', DoubleType(), True),
    StructField('high', DoubleType(), True),
    StructField('low', DoubleType(), True),
    StructField('close', DoubleType(), True),
    StructField('volume', IntegerType(), True),
    StructField('adjusted', DoubleType(), True),
    StructField('market_cap', StringType(), True),
    StructField('sector', StringType(), True),
    StructField('industry', StringType(), True),
    StructField('exchange', StringType(), True),
]
final_struc = StructType(fields=data_schema)
data = spark.read.csv('data/stocks_price_final.csv', sep=',', header=True, schema=final_struc)
data.printSchema()
The code above defines the data structure with StructType and StructField and passes it to spark.read.csv() through the schema parameter. Here is the resulting structured schema:
root
|-- _c0: integer (nullable = true)
|-- symbol: string (nullable = true)
|-- data: date (nullable = true)
|-- open: double (nullable = true)
|-- high: double (nullable = true)
|-- low: double (nullable = true)
|-- close: double (nullable = true)
|-- volume: integer (nullable = true)
|-- adjusted: double (nullable = true)
|-- market_cap: string (nullable = true)
|-- sector: string (nullable = true)
|-- industry: string (nullable = true)
|-- exchange: string (nullable = true)

Different methods of data inspection

The following methods are available for inspecting data: schema, dtypes, show, head, first, take, describe, columns, count, distinct, printSchema. Let's go through them with examples.

  • schema: this property returns the DataFrame's schema as a StructType object. The output for the stock price data is shown below.

data.schema
# -------------- Output ------------------
# StructType(
# List(
# StructField(_c0,IntegerType,true),
# StructField(symbol,StringType,true),
# StructField(data,DateType,true),
# StructField(open,DoubleType,true),
# StructField(high,DoubleType,true),
# StructField(low,DoubleType,true),
# StructField(close,DoubleType,true),
# StructField(volume,IntegerType,true),
# StructField(adjusted,DoubleType,true),
# StructField(market_cap,StringType,true),
# StructField(sector,StringType,true),
# StructField(industry,StringType,true),
# StructField(exchange,StringType,true)
# )
# )

  • dtypes returns a list of tuples with column names and data types.

data.dtypes
#------------- Output ------------
# [('_c0', 'int'),
# ('symbol', 'string'),
# ('data', 'date'),
# ('open', 'double'),
# ('high', 'double'),
# ('low', 'double'),
# ('close', 'double'),
# ('volume', 'int'),
# ('adjusted', 'double'),
# ('market_cap', 'string'),
# ('sector', 'string'),
# ('industry', 'string'),
# ('exchange', 'string')]

  • head(n) returns the first n rows as a list of Row objects. Here is an example:

data.head(3)
# ---------- Output ---------
# [
# Row(_c0=1, symbol='TXG', data=datetime.date(2019, 9, 12), open=54.0, high=58.0, low=51.0, close=52.75, volume=7326300, adjusted=52.75, market_cap='$9.31B', sector='Capital Goods', industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ'),
# Row(_c0=2, symbol='TXG', data=datetime.date(2019, 9, 13), open=52.75, high=54.355, low=49.150002, close=52.27, volume=1025200, adjusted=52.27, market_cap='$9.31B', sector='Capital Goods', industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ'),
# Row(_c0=3, symbol='TXG', data=datetime.date(2019, 9, 16), open=52.450001, high=56.0, low=52.009998, close=55.200001, volume=269900, adjusted=55.200001, market_cap='$9.31B', sector='Capital Goods', industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ')
# ]

  • show() prints the first 20 rows by default; pass a number to display a different number of rows.
  • first() returns the first row of data.

data.first()
# ----------- Output -------------
# Row(_c0=1, symbol='TXG', data=datetime.date(2019, 9, 12), open=54.0, high=58.0, low=51.0,
# close=52.75, volume=7326300, adjusted=52.75, market_cap='$9.31B', sector='Capital Goods',
# industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ')

  • take(n) returns the first n rows as a list.
  • describe() computes summary statistics (count, mean, standard deviation, minimum, maximum) for numeric and string columns.
  • columns returns a list containing the column names.

data.columns
# --------------- Output --------------
# ['_c0',
# 'symbol',
# 'data',
# 'open',
# 'high',
# 'low',
# 'close',
# 'volume',
# 'adjusted',
# 'market_cap',
# 'sector',
# 'industry',
# 'exchange']

  • count() returns the total number of rows in the dataset.

data.count()  # returns the number of rows
# -------- Output ---------
# 1292361

  • distinct() returns a new DataFrame with duplicate rows removed; chain count() to get the number of distinct rows.
  • printSchema() displays the data schema.

data.printSchema()
# ------------ Output ------------
# root
# |-- _c0: integer (nullable = true)
# |-- symbol: string (nullable = true)
# |-- data: date (nullable = true)
# |-- open: double (nullable = true)
# |-- high: double (nullable = true)
# |-- low: double (nullable = true)
# |-- close: double (nullable = true)
# |-- volume: integer (nullable = true)
# |-- adjusted: double (nullable = true)
# |-- market_cap: string (nullable = true)
# |-- sector: string (nullable = true)
# |-- industry: string (nullable = true)
# |-- exchange: string (nullable = true)

Manipulating columns

Let's see which methods are used to add, update, and delete columns.

1. Adding a column: use withColumn to add a new column to the existing ones. The method takes two parameters: the column name and a Column expression that produces its values. Example:
data = data.withColumn('date', data.data)
2. Updating a column: use withColumnRenamed to rename an existing column. The method takes two parameters: the name of the existing column and its new name. Example:
data = data.withColumnRenamed('date', 'data_changed')
3. Removing a column: use the drop method, which takes the column name and returns a new DataFrame without that column.
data = data.drop('data_changed')

Handling missing values

We often encounter missing values when working with real-world data. They may appear as NaN, blanks, or other placeholders. There are various methods for dealing with missing values; some of the most popular are:

  • Delete: delete rows with missing values in any of the columns.
  • Replace with mean/median: replace missing values with the mean or median of the corresponding column. This is simple and fast, and works well with small numerical datasets.
  • Replace with the most frequent value: as the name implies, use the most frequent value in the column to replace missing values. This works well with categorical features, but can introduce bias into the data.
  • Imputation using KNN: the k-nearest neighbors algorithm finds the existing data points most similar to the incomplete one, using distance metrics such as Euclidean, Mahalanobis, Manhattan, Minkowski, Hamming, and others, and fills the gap from those neighbors. This approach is often more accurate than the methods above, but it requires more computational resources and is quite sensitive to outliers.

Let’s see how we can use PySpark to deal with missing values:
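from pyspark.sql import functions as f
# Delete rows with missing values
data.na.drop()
# Replace missing values in 'open' with the column mean
data.na.fill(data.select(f.mean(data['open'])).collect()[0][0], subset=['open'])
# Replace a placeholder value with a new value (old_value and new_value are placeholders you define)
data.na.replace(old_value, new_value)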

Getting the data

PySpark and PySpark SQL provide a wide range of methods and functions for convenient data querying. Here is a list of the most commonly used methods:

  • Select
  • Filter
  • Between
  • When
  • Like
  • GroupBy
  • Aggregation

Select

select() is used to select one or more columns by name. Here is a simple example:
# Select one column
data.select('sector').show(5)
# Select multiple columns
data.select(['open', 'close', 'adjusted']).show(5)

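Filter

filter() keeps only the rows that satisfy a given condition. You can combine multiple conditions with the AND (&), OR (|), and NOT (~) operators. Wrap each condition in parentheses, because in Python these operators bind more tightly than comparisons. Here is an example of getting stock price data for January 2020.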
from pyspark.sql.functions import col, lit
data.filter( (col('data') >= lit('2020-01-01')) & (col('data') <= lit('2020-01-31')) ).show(5)

Between

between() returns True if the value being checked lies within the given range (both bounds included) and False otherwise. Let's select the rows in which the adjusted price is between 100 and 500.
data.filter(data.adjusted.between(100.0, 500.0)).show()

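When

when() evaluates a condition and returns the given value for rows where it holds; otherwise() sets the value for all other rows. The example below selects the opening and closing prices together with a flag that is 1 when the adjusted price is greater than or equal to 200 and 0 otherwise.
from pyspark.sql import functions as f
data.select('open', 'close',
            f.when(data.adjusted >= 200.0, 1).otherwise(0)
           ).show(5)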

Like

rlike() is similar to the LIKE operator in SQL, but it matches a regular expression (Column.like() accepts SQL wildcards instead). The code below uses rlike() to flag sector names that begin with the letter B or C.
data.select('sector',
            data.sector.rlike('^[BC]').alias('sector column starts with B or C')
           ).distinct().show()

GroupBy

As the name suggests, groupBy() groups the data by the selected column so that we can compute aggregates such as the sum, average, minimum, or maximum for each group. The example below computes the average opening, closing, and adjusted stock prices by industry.
data.select(['industry', 'open', 'close', 'adjusted']).groupBy('industry').mean().show()

Aggregation

PySpark provides built-in standard aggregation functions in the DataFrame API, which come in handy when we need to aggregate column values. These functions operate on groups of rows and compute a single value for each group. The example below shows the minimum, maximum, and average opening, closing, and adjusted stock prices for each sector between January 2019 and January 2020.
from pyspark.sql import functions as f
from pyspark.sql.functions import col, lit
data.filter((col('data') >= lit('2019-01-02')) & (col('data') <= lit('2020-01-31')))\
    .groupBy("sector")\
    .agg(f.min("open").alias("Minimum open"),
         f.max("open").alias("Maximum open"),
         f.avg("open").alias("Average open"),
         f.min("close").alias("Minimum close"),
         f.max("close").alias("Maximum close"),
         f.avg("close").alias("Average close"),
         f.min("adjusted").alias("Minimum adjusted"),
         f.max("adjusted").alias("Maximum adjusted"),
         f.avg("adjusted").alias("Average adjusted"))\
    .show(truncate=False)

Data visualization

To visualize the data, we will use the matplotlib and pandas libraries. The toPandas() method converts a Spark DataFrame into a pandas DataFrame, on which we can call the plot() visualization method. The code below draws a bar chart of the average opening, closing, and adjusted stock prices for each sector.
from matplotlib import pyplot as plt
sec_df = data.select(['sector', 'open', 'close', 'adjusted']).groupBy('sector').mean().toPandas()
ind = list(range(12))
sec_df.iloc[ind ,:].plot(kind='bar', x='sector', y=sec_df.columns.tolist()[1:],
figsize=(12, 6), ylabel='Stock Price', xlabel='sector')
Now let's visualize the same averages, but by industry.
industries_x = data.select(['industry', 'open', 'close', 'adjusted']).groupBy('industry').mean().toPandas()
q = industries_x[(industries_x.industry != 'Major Chemicals') & (industries_x.industry != 'Building Products')]
q.plot(kind='barh', x='industry', y=q.columns.tolist()[1:], figsize=(10, 50), xlabel='Stock Price', ylabel='Industry')
Let’s also construct a time series for the average opening, closing, and adjusted stock prices of the technology sector.
from pyspark.sql import functions as f
# Daily averages for the tech sector (assumes the sector label 'Technology' exists in this dataset)
tech = data.where(f.col('sector') == 'Technology').groupBy('data')\
    .agg(f.avg('open').alias('open'), f.avg('close').alias('close'), f.avg('adjusted').alias('adjusted'))\
    .orderBy('data').toPandas()
tech.plot(kind='line', x='data', y=['open', 'close', 'adjusted'], figsize=(12, 6), xlabel='Date', ylabel='Stock Price')

Writing/saving data to file

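The DataFrameWriter returned by data.write saves data in various formats, such as CSV, JSON, Parquet, and others. Let's look at how to write data to files in different formats. We can save either all columns or only selected ones using the select() method.
# Write the entire data into different file formats
data.write.csv('dataset.csv')  # CSV
data.write.save('dataset.json', format='json')  # JSON
data.write.save('dataset.parquet', format='parquet')  # Parquet
# Write the selected data into different file formats; mode='overwrite' replaces the output written above
data.select(['data', 'open', 'close', 'adjusted']).write.csv('dataset.csv', mode='overwrite')  # CSV
data.select(['data', 'open', 'close', 'adjusted']).write.save('dataset.json', format='json', mode='overwrite')  # JSON
data.select(['data', 'open', 'close', 'adjusted']).write.save('dataset.parquet', format='parquet', mode='overwrite')  # Parquet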

Conclusion

PySpark is a great tool for data scientists because it provides scalable analysis and ML pipelines. If you are already familiar with Python, SQL, and pandas, PySpark is a good option for a quick start. This article has shown how to perform a wide range of operations with PySpark, from reading files to writing results. We also covered basic visualization techniques using the matplotlib library. We saw that Google Colaboratory notebooks are a convenient place to start learning PySpark without a lengthy installation of the necessary software. Feel free to use the code provided in the article, which is available on Gitlab. Have a great learning experience.
