PySpark is an Apache Spark API, which is an open source system used for distributed big data processing. It was originally developed in the Scala programming language at the University of California, Berkeley. Spark provides APIs for Scala, Java, Python and R. The system supports code reuse between work tasks, batch data processing, interactive queries, real-time analytics, machine learning, and graph computing. It uses in-memory caching and optimized query execution for data of any size. It does not have one proprietary file system such as Hadoop Distributed File System (HDFS), instead Spark supports many popular file systems such as HDFS, HBase, Cassandra, Amazon S3, Amazon Redshift, Couchbase, etc. Benefits of using Apache Spark:
- It runs programs in memory up to 100 times faster than Hadoop MapReduce, and 10 times faster on disk, because Spark performs processing in the main memory of working nodes and prevents unnecessary I/O operations.
- Spark is extremely user-friendly because it has APIs written in popular languages, making it easy for developers: this approach hides the complexity of distributed processing behind simple high-level operators, which greatly reduces the amount of code needed.
- The system can be deployed using Mesos, Hadoop via Yarn, or Spark’s own cluster manager.
- Spark performs calculations in real time and provides low latency due to their resident (in-memory) execution.
Let’s get started.
Table of Contents
Setting up the environment in Google Colab
To run pyspark on the local machine, we need Java and some other software. So instead of a complicated installation procedure, we use Google Colaboratory, which perfectly meets our hardware requirements and also comes with a wide range of libraries for data analysis and machine learning. Thus, all we have to do is install the pyspark and Py4J packages. Py4J allows Python programs running in the Python interpreter to dynamically access Java objects from the Java virtual machine. The final notebook can be downloaded from the repository: https://gitlab.com/PythonRu/notebooks/-/blob/master/pyspark_beginner.ipynb Command to install the above packages: !pip install pyspark==3.0.1 py4j==0.10.9
Spark Session
SparkSession has been the entry point into PySpark since version 2.0: previously SparkContext was used for this. SparkSession is a way to initialize basic PySpark functionality to programmatically create PySpark RDDs, DataFrames and Datasets. It can be used instead of SQLContext, HiveContext and other contexts defined before 2.0. You should also be aware that SparkSession internally creates SparkConfig
and SparkContext
with the configuration provided with SparkSession. You can create a SparkSession with SparkSession.builder
, which is an implementation of the Builder design pattern.
Creating a SparkSession
To create a SparkSession, you need to use the builder()
method.
getOrCreate()
returns an existing SparkSession; if it doesn’t exist, a new SparkSession is created.master()
: if you are working with a cluster, you need to pass the name of your cluster manager as an argument. Normally this will be eitheryarn
ormesos
depending on your cluster setup, and when working offline,local[x]
is used. Here, x should be an integer greater than 0. This value indicates how many partitions will be created when using RDD, DataFrame and Dataset. IdeallyX
should correspond to the number of CPU cores.appName()
is used to set the name of your application.
An example of creating a SparkSession:
from pyspark.sql import SparkSession
spark = SparkSession.builder
.master("local[*]")
.appName('PySpark_Tutorial')
.getOrCreate()
# where "*" stands for all processor cores.
Read data
Using spark.read
we can read data from various file formats like CSV, JSON, Parquet, and others. Here are some examples of getting data from files:
# Reading a CSV file
csv_file = ‘data/stocks_price_final.csv’
df = spark.read.csv(csv_file)
# read JSON file
json_file = ‘data/stocks_price_final.json’
data = spark.read.json(json_file)
# Read parquet file
parquet_file = ‘data/stocks_price_final.parquet’
data1 = spark.read.parquet(parquet_file)
Structuring data with Spark schema
Let’s read the U.S. stock price data from January 2019 to July 2020, which is available in the Kaggle dataset. The code to read the data is in CSV file format:
data = spark.read.csv(
‘stocks_price_final.csv,
sep=’,’,
header=True,
)
data.printSchema()
Now let’s look at the data schema using the PrintSchema
method. The Spark schema displays the structure of a data frame or dataset. We can define it using the
StructType
class, which is a collection of StructField
objects. These in turn set the column name (String), its type (DataType), whether it accepts NULL (Boolean), and the metadata (MetaData). This can be quite useful even though Spark automatically derives the schema from the data, because sometimes the type it assumes may not be correct, or we need to define our own column names and data types. This often happens when working with completely or partially unstructured data. Let’s see how we can structure our data:
from pyspark.sql.types import *
data_schema = [
StructField('_c0', IntegerType(), True),
StructField('symbol', StringType(), True),
StructField('data', DateType(), True),
StructField('open', DoubleType(), True),
StructField('high', DoubleType(), True),
StructField('low', DoubleType(), True),
StructField('close', DoubleType(), True),
StructField('volume', IntegerType(), True),
StructField('adjusted', DoubleType(), True),
StructField('market.cap', StringType(), True),
StructField('sector', StringType(), True),
StructField('industry', StringType(), True),
StructField('exchange', StringType(), True),
]
final_struc = StructType(fields = data_schema)
data = spark.read.csv(
'stocks_price_final.csv',
sep=',',
header=True,
schema=final_struc
)
data.printSchema()
The above code creates a data structure using StructType
and StructField
. It is then passed as a schema parameter to the spark.read.csv()
method. Let's take a look at the resulting structured data schema: root
|-- _c0: integer (nullable = true)
|-- symbol: string (nullable = true)
|-- data: date (nullable = true)
|-- open: double (nullable = true)
|-- high: double (nullable = true)
|-- low: double (nullable = true)
|-- close: double (nullable = true)
|-- volume: integer (nullable = true)
|-- adjusted: double (nullable = true)
|-- market.cap: string (nullable = true)
|-- sector: string (nullable = true)
|-- industry: string (nullable = true)
|-- exchange: string (nullable = true)
Different methods of data inspection
There are the following data inspection methods: schema, dtypes, show, head, first, take, describe, columns, count, distinct, printSchema. Let’s understand them with an example.
schema()
: this method returns the data schema (data frame). An example with stock prices is shown below.
data.schema
# -------------- Output ------------------
# StructType(
# List(
# StructField(_c0,IntegerType,true),
# StructField(symbol,StringType,true),
# StructField(data,DateType,true),
# StructField(open,DoubleType,true),
# StructField(high,DoubleType,true),
# StructField(low,DoubleType,true),
# StructField(close,DoubleType,true),
# StructField(volume,IntegerType,true),
# StructField(adjusted,DoubleType,true),
# StructField(market_cap,StringType,true),
# StructField(sector,StringType,true),
# StructField(industry,StringType,true),
# StructField(exchange,StringType,true)
# )
# )
dtypes
returns a list of tuples with column names and data types.
data.dtypes
#------------- Output ------------
# [('_c0', 'int'),
# ('symbol', 'string'),
# ('data', 'date'),
# ('open', 'double'),
# ('high', 'double'),
# ('low', 'double'),
# ('close', 'double'),
# ('volume', 'int'),
# ('adjusted', 'double'),
# ('market_cap', 'string'),
# ('sector', 'string'),
# ('industry', 'string'),
# ('exchange', 'string')]
head(n)
returns n rows as a list. Here is an example:
data.head(3)
# ---------- Output ---------
# [
# Row(_c0=1, symbol='TXG', data=datetime.date(2019, 9, 12), open=54.0, high=58.0, low=51.0, close=52.75, volume=7326300, adjusted=52.75, market_cap='$9.31B', sector='Capital Goods', industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ'),
# Row(_c0=2, symbol='TXG', data=datetime.date(2019, 9, 13), open=52.75, high=54.355, low=49.150002, close=52.27, volume=1025200, adjusted=52.27, market_cap='$9.31B', sector='Capital Goods', industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ'),
# Row(_c0=3, symbol='TXG', data=datetime.date(2019, 9, 16), open=52.450001, high=56.0, low=52.009998, close=55.200001, volume=269900, adjusted=55.200001, market_cap='$9.31B', sector='Capital Goods', industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ')
# ]
show()
displays the first 20 rows by default, and takes a number as a parameter to select their number.first()
returns the first line of data.
data.first()
# ----------- Output -------------
# Row(_c0=1, symbol='TXG', data=datetime.date(2019, 9, 12), open=54.0, high=58.0, low=51.0,
# close=52.75, volume=7326300, adjusted=52.75, market_cap='$9.31B', sector='Capital Goods',
# industry='Biotechnology: Laboratory Analytical Instruments', exchange='NASDAQ')
take(n)
returns the first n lines.describe()
calculates some statistical values for columns with a numeric data type.columns
returns a list containing the column names.
data.columns
# --------------- Output --------------
# ['_c0',
# 'symbol',
# 'data',
# 'open',
# 'high',
# 'low',
# 'close',
# 'volume',
# 'adjusted',
# 'market_cap',
# 'sector',
# 'industry',
# 'exchange']
count()
returns the total number of rows in the dataset.
data.count()
# returns the number of lines of data
# -------- Output ---------
# 1292361
distinct()
– the number of different rows in the dataset being used.printSchema()
displays the data schema.
df.printSchema()
# ------------ Output ------------
# root
# |-- _c0: integer (nullable = true)
# |-- symbol: string (nullable = true)
# |-- data: date (nullable = true)
# |-- open: double (nullable = true)
# |-- high: double (nullable = true)
# |-- low: double (nullable = true)
# |-- close: double (nullable = true)
# |-- volume: integer (nullable = true)
# |-- adjusted: double (nullable = true)
# |-- market_cap: string (nullable = true)
# |-- sector: string (nullable = true)
# |-- industry: string (nullable = true)
# |-- exchange: string (nullable = true)
Manipulations with columns
Let’s see what methods are used to add, update, and delete columns of data. 1. Adding a column: use withColumn
to add a new column to existing ones. The method takes two parameters: column name and data. Example:
2. Updating a column: use
data = data.withColumn('date', data.data)
data.show(5)
withColumnRenamed
to rename an existing column. The method takes two parameters: the name of the existing column and its new name. Example:
3. Removing a column: use the
data = data.withColumnRenamed('date', 'data_changed')
data.show(5)
drop
method, which takes the column name and returns the data.
data = data.drop('data_changed')
data.show(5)
Handling missing values
We often encounter missing values when working with real-time data. These missing values are denoted as NaN, blanks, or other placeholders. There are various methods for dealing with missing values, some of the most popular:
- Delete: delete rows with missing values in any of the columns.
- Replace with mean/median: replace missing values using the mean or median of the corresponding column. This is simple, fast, and works well with small sets of numerical data.
- Replace with most frequent values: as the name implies, use the most frequent value in the column to replace missing values. This works well with categorical features, but can also introduce bias into the data.
- Substitution using KNN: The K-Nearest Neighbor method is a classification algorithm that calculates feature similarity of new data points to existing ones using various distance metrics, such as Euclidean, Mahalanobis, Manhattan, Minkowski, Hamming, and others. This approach is more accurate than the aforementioned methods, but it requires more computational resources and is quite sensitive to outliers.
Let’s see how we can use PySpark to deal with missing values:
# Deleting rows with missing values
data.na.drop()
# Replacing missing values with the middle one
data.na.fill(data.select(f.mean(data['open'])).collect()[0][0])
# Replace missing values with new values
data.na.replace(old_value, new_vallue)
Getting the data
PySpark and PySpark SQL provide a wide range of methods and functions for convenient data querying. Here is a list of the most commonly used methods:
- Select
- Filter
- Between
- When
- Like
- GroupBy
- Aggregation
Select
It is used to select one or more columns using their names. Here is a simple example:
# Select one column
data.select('sector').show(5)
# Select multiple columns
data.select(['open', 'close', 'adjusted']).show(5)
Filter
This method filters the data based on a given condition. You can also specify multiple conditions using the AND (&), OR (|), and NOT (~) operators. Here is an example of getting stock price data for January 2020.
from pyspark.sql.functions import col, lit
data.filter( (col('data') >= lit('2020-01-01')) & (col('data') <= lit('2020-01-31')) ).show(5)
Between
This method returns True
if the value being checked belongs to the specified segment, otherwise it returns False
. Let’s look at an example selection of data in which the adjusted
values are between 100 and 500.
data.filter(data.adjusted.between(100.0, 500.0)).show()
When
It returns 0 or 1 depending on the condition set. The example below shows how to select prices at the opening and closing times when the adjusted price was greater than or equal to 200.
data.select('open', 'close',
f.when(data.adjusted >= 200.0, 1).otherwise(0)
).show(5)
Like
This method is similar to the Like statement in SQL. The code below demonstrates the use of rlike()
to retrieve sector names that begin with the letters M or C.
data.select(
'sector',
data.sector.rlike('^[B,C]').alias('sector column starts with B or C')
).distinct().show()
GourpBy
The name itself suggests that this function groups data by a selected column and performs various operations, such as calculating the sum, average, minimum, maximum, etc. The example below explains how to get the average opening, closing, and adjusted stock price by industry.
data.select(['industry', 'open', 'close', 'adjusted'])
.groupBy('industry')
.mean()
.show()
Aggregation
PySpark provides built-in standard aggregation functions defined in the DataFrame API, which come in handy when we need to aggregate your column values. In other words, these functions work on groups of rows and calculate a single return value for each group. The example below shows how to display the minimum, maximum, and average open, close, and adjusted stock prices between January 2019 and January 2020 for each sector.
from pyspark.sql import functions as f
data.filter((col('data') >= lit('2019-01-02')) & (col('data') <= lit('2020-01-31')))
.groupBy('sector')
.agg(f.min("data").alias("C"),
f.max("data").alias("Po"),
f.min("open").alias("Minimum when open"),
f.max("open").alias("Maximum when open"),
f.avg("open").alias("Average in open"),
f.min("close").alias("Minimum at close"),
f.max("close").alias("Maximum at close"),
f.avg("close").alias("Average in close"),
f.min("adjusted").alias("adjusted minimum"),
f.max("adjusted").alias("adjusted maximum"),
f.avg("adjusted").alias("Average in adjusted"),
).show(truncate=False)
Data visualization
To visualize the data, we will use the matplotlib and pandas libraries. The toPandas()
method allows us to convert the data into a dataframe pandas, which we use when calling the plot()
visualization method. The code below shows how to display a histogram showing the average open, close, and adjusted stock prices for each sector.
from matplotlib import pyplot as plt
sec_df = data.select(['sector',
'open',
'close',
'adjusted']
)
.groupBy('sector')
.mean()
.toPandas()
ind = list(range(12))
ind.pop(6)
sec_df.iloc[ind ,:].plot(kind='bar', x='sector', y=sec_df.columns.tolist()[1:],
figsize=(12, 6), ylabel='Stock Price', xlabel='sector')
plt.show()
Now let’s visualize the same averages, but by sector.
Let’s also construct a time series for the average opening, closing, and adjusted stock prices of the technology sector.
industries_x = data.select(['industry', 'open', 'close', 'adjusted']).groupBy('industry').mean().toPandas()
q = industries_x[(industries_x.industry != 'Major Chemicals') & (industries_x.industry != 'Building Products')]
q.plot(kind='barh', x='industry', y=q.columns.tolist()[1:], figsize=(10, 50), xlabel='Stock Price', ylabel='Industry')
plt.show()
industries_x = data.select(['industry', 'open', 'close', 'adjusted']).groupBy('industry').mean().toPandas()
q = industries_x[(industries_x.industry != 'Major Chemicals') & (industries_x.industry != 'Building Products')]
q.plot(kind='barh', x='industry', y=q.columns.tolist()[1:], figsize=(10, 50), xlabel='Stock Price', ylabel='Industry')
plt.show()
Writing/saving data to file
The write.save()
method is used to save data in various formats, such as CSV, JSVON, Parquet, and others. Let’s look at how to write data to files in different formats. We can save both all rows and only selected ones using the select()
method.
# CSV
data.write.csv(‘dataset.csv’)
# JSON
data.write.save(‘dataset.json’, format=’json’)
# Parquet
data.write.save(‘dataset.parquet’, format=’parquet’)
# Write the selected data into the different file formats
# CSV
data.select([‘data’, ‘open’, ‘close’, ‘adjusted’])
.write.csv(‘dataset.csv’)
# JSON
data.select([‘data’, ‘open’, ‘close’, ‘adjusted’])
.write.save(‘dataset.json’, format=’json’)
# Parquet
data.select([‘data’, ‘open’, ‘close’, ‘adjusted’])
.write.save(‘dataset.parquet’, format=’parquet’)
Conclusion
PySpark is a great tool for data scientists because it provides scalable analysis and ML pipelines. If you are already familiar with Python, SQL, and Pandas, then PySpark is a good option for a quick start. This article has shown you how to perform a wide range of operations, from reading files to writing results using PySpark. We also covered basic visualization techniques using the matplotlib library. We learned that Google Colaboratory Notebooks are a convenient place to start learning PySpark without a long installation of the necessary software. Be sure to check out the links below to resources that can help you learn PySpark faster and easier. Also, feel free to use the code provided in the article, which can be accessed by going to Gitlab. Have a great learning experience.